Applied Deep Learning - Computer Vision for Age and Gender prediction¶

💡 📢 ☑️ remember to read the readme.md file for helpful hints on the best ways to view/navigate this project

If you view this notebook on GitHub, you will be missing important content.

Some charts/diagrams/features are not visible on GitHub. This is standard, well-known behaviour of its notebook renderer.

Consider viewing the pre-rendered HTML files, or run all notebooks end to end after enabling the feature flags that control long-running operations.

If you choose to run this locally, there are some prerequisites:

  • you will need python 3.9
  • you will need to install the dependencies using pip install -r requirements.txt before proceeding.
  • you will need additional system packages cm-super, dvipng for correct rendering of some LaTeX content

Problem Statement (provided by Turing College)¶

Module 4: Deep Learning - Sprint 3: Practical Deep Learning
===========================================================

Age and Gender classification
-----------------------------

Congratulations on reaching your last project. We will try to put into practice the concepts we learned so far.

In this lesson, we will take two fairly simple problems - gender classification and age classification from an up-close image of a person. But instead of making two different models, your task will be to make one model that does both of these tasks. Moreover, you will then analyze the model from the ethical point of view and see what sort of dangers and caveats such models can have.

The exercise today is to train a multi-objective image classifier using data from https://www.kaggle.com/jangedoo/utkface-new. You will train a single model that can predict gender and age.

Find out more about multi-task learning: https://ruder.io/multi-task, https://www.youtube.com/watch?v=UdXfsAr4Gjw

Concepts to explore
-------------------

-   Classification task
-   Convolutional neural network
-   AI ethics and bias
-   Model interpretability

Requirements
------------

-   You should go through the standard cycle of EDA-model-evaluation.
-   Create a single model that returns age and gender in a single pass
-   Analyze model performance
-   Understand which are the best/worst performing samples.
-   Use LIME for model interpretability with images. Understand what your model has learned.

Once you are done with these tasks, evaluate any ethical issues with this model:

-   Identify how this model can be biased and check if the results show signs of these issues.
-   Analyze bad predictions. Do you see any patterns in misclassified samples that could cause ethical concerns?
-   Describe in which scenarios your model can be biased. Propose solutions to mitigate it.
-   Think of a domain where this model could (or could not) be deployed.

Evaluation criteria
-------------------

-   EDA
-   A single end-to-end trainable deep learning model is built
-   Correctly modeled classification/regression
-   Correct selection of loss function(s)
-   Model interpretability tools used and insights made
-   Model aggregate performance
-   Quality of ethical concerns raised
-   Code quality

Imports and initial setup¶

In [1]:
from IPython.display import display, Markdown, clear_output, HTML, IFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
from matplotlib import gridspec
import itertools
import glob
from tqdm import tqdm
import dill
from datetime import datetime
import numba

import imagehash
import numpy as np
import pandas as pd
import seaborn as sns

# from fastai.vision.all import
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms, datasets, models
import torch.nn.functional as F

import lime
from lime import lime_image

from torchmetrics.classification import (
    MulticlassConfusionMatrix,
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from torchmetrics import ConfusionMatrix

from PIL import Image, ImageDraw, ImageFile
from scipy.stats import chi2_contingency, chisquare, laplace, kstest


from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    auc,
    ConfusionMatrixDisplay,
)

from random import random, seed, shuffle
import logging
import warnings

import os
import shutil
from os import path

from watermark import watermark
In [2]:
from utils import *
from utils import __
loading utils modules... ✅ completed
configuring autoreload... ✅ completed
In [3]:
print(watermark())
print(watermark(packages="torch,torchvision,torchmetrics,numpy,pandas,sklearn,scipy"))
print(watermark(conda=True))
Last updated: 2024-02-29T09:30:23.537801+01:00

Python implementation: CPython
Python version       : 3.9.16
IPython version      : 8.10.0

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 5.15.0-88-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 16
Architecture: 64bit

torch       : 2.1.0
torchvision : 0.16.0
torchmetrics: 1.2.0
numpy       : 1.23.5
pandas      : 2.1.4
sklearn     : 1.2.2
scipy       : 1.9.3

conda environment: py39_lab4

In [4]:
seed(100)
pd.options.display.max_rows = 30
pd.options.display.max_colwidth = 50

util.check("done")
✅

Let's use black to auto-format all our cells so they adhere to PEP8

In [5]:
import lab_black

%reload_ext lab_black
util.patch_nb_black()
# fmt: off
# fmt: on
In [6]:
from sklearn import set_config

set_config(transform_output="pandas")
In [7]:
sns.set_theme(context="notebook", style="whitegrid")
plt.rcParams["axes.grid"] = True

moonstone = "#62b6cb"
moonstone_rgb = util.hex_to_rgb(moonstone)
moonstone_rgb_n = np.array(moonstone_rgb) / 255
In [8]:
logger = util.configure_logging(jupyterlab_level=logging.WARN, file_level=logging.DEBUG)

warnings.filterwarnings("ignore", category=FutureWarning)

# import warnings
# warnings.filterwarnings('error', category=pd.errors.DtypeWarning)

Configure logging¶

Thanks kkalera for this great snippet

In [9]:
# kkalera's logger config
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

file_handler = logging.FileHandler("notebook_logging.log")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)

stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)
In [10]:
def ding(title="Ding!", message="Task completed"):
    """
    this method only works on linux.
    I'm only using it to notify me when a long running process completes
    """
    for i in range(2):
        !notify-send '{title}' '{message}'

Feature Toggles¶

Let's also create a simple feature toggle that we can use to skip expensive operations during notebook work (to save ourselves some time!).

Set it to True if you want to run absolutely everything. Set it to False to skip optional steps/exploratory work.

In [11]:
def run_entire_notebook(filename: str = None, value_only=False):
    run_all_flag = False
    if value_only:
        return run_all_flag

    if not run_all_flag:
        print("skipping optional operation")
        fullpath = f"cached/printouts/{filename}.txt"
        if filename is not None and os.path.exists(fullpath):
            print("==== 🗃️ printing cached output ====")
            with open(fullpath) as f:
                print(f.read())
    return run_all_flag

Fetching Data¶

In [12]:
kaggle_dataset_name = "jangedoo/utkface-new"
db_filename = "UTKFace"

auto_kaggle.download_dataset(
    kaggle_dataset_name,
    db_filename,
    timeout_seconds=3 * 60,
    is_competition=False,
)
__
Kaggle API 1.5.13 - login as 'edualmas'
File [dataset/UTKFace] already exists locally!
No need to re-download data for dataset [jangedoo/utkface-new]

Let's take a quick look at the downloaded dataset to see if we see any interesting patterns

In [13]:
# SOURCE: https://susanqq.github.io/UTKFace/

gender_map = {
    0: "male",
    1: "female",
}

ethnicity_map = {
    0: "white",
    1: "black",
    2: "asian",
    3: "indian",
    4: "other",
}
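Per the dataset description linked above, each UTKFace filename encodes the labels as `[age]_[gender]_[race]_[date&time].jpg.chip.jpg`. A minimal parsing sketch (the `parse_utk_filename` helper is ours, and it assumes well-formed filenames; the maps from the cell above are repeated so the snippet is self-contained):

```python
import os

gender_map = {0: "male", 1: "female"}
ethnicity_map = {0: "white", 1: "black", 2: "asian", 3: "indian", 4: "other"}


def parse_utk_filename(path: str) -> dict:
    """Extract age/gender/ethnicity labels from a UTKFace filename,
    which follows [age]_[gender]_[race]_[date&time].jpg.chip.jpg."""
    stem = os.path.basename(path)
    age, gender, race = stem.split("_")[:3]
    return {
        "age": int(age),
        "gender": gender_map[int(gender)],
        "ethnicity": ethnicity_map[int(race)],
    }


parse_utk_filename("dataset/UTKFace/10_0_0_20161220222308131.jpg.chip.jpg")
# → {'age': 10, 'gender': 'male', 'ethnicity': 'white'}
```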

Dataset Cleaning¶

Some kaggle datasets have some amount of data duplication.

Let's take a quick look and see if we can save valuable disk space (and training time) by removing all duplicated data.

Removing Duplicated files¶

Let's take the easy steps first: understanding the dataset and detecting/deleting any duplicated files

Using disk inspector¶

We see that the folder structure might show some duplication:

  • Nested folder structure
  • Same folder names
  • Same number of files
  • Same total folder size

We are not 100% sure that there is duplication, but we would not be surprised if there were some.

Let's explore a bit more in depth and see what we find. Right now, there are a few scenarios that we could encounter:

Ideal scenario:

  1. The entire folder (UTKFace) is duplicated and placed inside the "utkface_aligned_cropped" folder.
  2. same, but with the "crop_part1" folder

If so, we can delete them both!

Plausible (non ideal) scenarios:

  1. The files are not byte-identical duplicates (different hash), but the picture data is identical and the only difference is in picture metadata (unlikely).
  2. Only some of the files are identical and we have to check them all individually and decide which ones to keep.
  3. All the files have minor tweaking to the image but we can see that most of them are imperceptible differences and it would be safe to just keep 1 set of the files

Using MELD¶

Let's use meld, a GUI tool for binary file and bulk folder comparisons, to compare the entire folders.

Initially we see that it says "identical contents" for the nested and non nested folders.

We want to make one last test to make sure. If we modify 1 of the files at random, manually, we see that it does detect the change ✅

If we edit its sibling file to have the same content, it no longer detects the change and goes back to saying "identical contents" (thus also proving that it's inspecting the binary content of the file, and not just the timestamp).

This gives us enough confidence to say that the crop_part1 folders are identical.

Let's do the same test on the other subfolder:

The same applies to the other subfolder (UTKFace vs utkface_aligned_cropped/UTKFace).

We're fairly confident that this evidence allows us to delete half the content of the dataset as duplicated (after performing binary file content comparison).

This saves us from processing ≥ 33,000 files, and frees 120MB of disk space.

Nothing kills disk performance more than accessing thousands of tiny files!

Using low-level sha256 Checksums¶

Let's do one LAST check, just to be extra sure... since making a mistake at this point would result in us incorrectly destroying half the data. A 5-minute check now can save us major problems later.

There seems to be a difference in the inner vs outer "crop" folder... but we suspect that it's because of the file we tampered with. Let's verify with checksums:

The only diff detected by sha256 is the file we tampered with ✅

We are 100% confident that these folders are definitely duplicates and we can delete them.

Comparing the 2 parts of the dataset¶

We have figured out that the 2 folders (out of 4) were identically duplicated, but we still have not determined what is in the folders that are left.

Removing Identical images (using filehash)¶

In [14]:
!ls -l dataset/crop_part1/ | wc --lines
9781
In [15]:
!ls -l dataset/UTKFace/ | wc --lines
23709

It seems they contain a different number of files each. It could be that one of the folders contains a subset of pics from the other one.

We really need to make sure we eliminate all duplicate images to avoid data leakage and artificially inflating the performance of our model.

A quick check with meld seems to show that one of the folders crop_part1 is a subset of the other folder UTKFace... ALMOST:

It seems that there is at least 1 file that is not in UTKFace. Let's quickly automate an easy way to get a single copy of each unique file. We don't want to do this manually and risk missing something.

We want to build a dict mapping filehash → filename, storing every file we encounter. Since dict keys are unique, this guarantees that only a single copy of each picture is kept (duplicate hashes simply collide with the existing key).
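As a sketch of the idea (the actual implementation below uses pandas `groupby().first()`, but the principle is identical; `first_per_hash` is just an illustrative helper):

```python
def first_per_hash(pairs):
    """Keep only the first filename seen for each filehash.

    Dict keys are unique, so later files carrying an already-seen
    hash collide with the existing key and are skipped.
    """
    unique = {}
    for filehash, filename in pairs:
        # setdefault stores filename only if filehash is not yet a key
        unique.setdefault(filehash, filename)
    return unique


first_per_hash([("h1", "a.jpg"), ("h1", "b.jpg"), ("h2", "c.jpg")])
# → {'h1': 'a.jpg', 'h2': 'c.jpg'}
```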

We will then be able to copy those files to a new sanitized folder of known-to-be-unique-pictures:

In [16]:
def list_all_files(dataset_folder: str) -> list[str]:
    return glob.glob(f"{dataset_folder}/*")

Let's calculate a hash for all the files:

In [17]:
if run_entire_notebook():
    !sha512sum dataset/crop_part1/*.jpg > dataset/individual_hashes.txt
    !sha512sum dataset/UTKFace/*.jpg >> dataset/individual_hashes.txt
skipping optional operation
In [18]:
!wc -l dataset/individual_hashes.txt
33488 dataset/individual_hashes.txt
In [19]:
!head -n 3 dataset/individual_hashes.txt
14ac5ab1a9d5dbd6243c82c578b413a0a980ae94b2b49f9d37c20b2eab4ec5222389a7f13e5b1ff968f684dba4f77d5475d1b5ab5e39c9361cfc6d570181df98  dataset/crop_part1/10_0_0_20161220222308131.jpg.chip.jpg
374819e76623a8844bc6e086c4336f411a53cee101687e52988928f63ec4b14e4b9a8f498d202c80feac5168fd900f97eacabdb814ef483e53ed0295143102a5  dataset/crop_part1/10_0_0_20170103200329407.jpg.chip.jpg
96efbbd97e6b37d2e4213dd04928f5bec0698d7d0ff347aee82f5e3e6a5c343e215cc56e2d7995a9da33b618b2d0f9b89fb26304c45ee8db208d55b71aae5097  dataset/crop_part1/10_0_0_20170103200522151.jpg.chip.jpg
In [20]:
hashes = pd.read_fwf(
    "dataset/individual_hashes.txt",
    widths=[128, 2, 1000],
    header=None,
)
hashes.columns = ["filehash", "sep", "filename"]
hashes.drop(columns="sep", inplace=True)
duplicated = hashes["filehash"].value_counts()
duplicated = duplicated[duplicated > 1]
# let's only count the duplicates, for our metrics.
# so 1 = 1 duplicate = 2 files with same content
duplicated = duplicated - 1
duplicated
Out[20]:
filehash
a23fa03989f3f48067fe5060db21924c456f51fd416bb0a92a4a797beee5c43f3d137318a199554ff33f4e0286a5b15beb78d22dddc41d058f058f25c5ee9f3e    5
706ea2e3da0bcaf39421356dddd35db2c4086108001e9eee5bdcec63479ca6a74738955dec6179f74985ca170d6e4a05e7228b4d03aeae709a91a229831c2378    5
0ef50bddb35c616916ed8133064a2be175add0e3e561fff3758bf9c44aef08ad11310ff25bc09ad7016c065dd64ea73fe90d68a8021148ea7f36de79c6f68170    5
8e0c4ee947b63171f37610628f4edad47fb002c0d6083dcf998726e373871ea00d852e9b7077e81c970f7812febcfd6918ef30026734809aa16794eb237c7d4c    5
952c372684eba466c7c178dd30d1346b67570c9ea2836237ef860f4ce928877feda7972f2791ea7e6ec4d326af105c4cb7d96bdcbd9170971f60244eb41463ab    5
                                                                                                                                   ..
b5b8e44204be12c627d4e7ede365eb2599d838d10ae5ca3cd63b6a0a0a8e61e75f84f2f097966098d6499c13bbd5320315270a570402ead4bf6e0e40cefc61d0    1
462c00d7a62686c570735275ef6ebad3a1d96a7454a09b5fbed3ff4b881c8563bc3cc1d5be3af606ae7c59e56e455bb9d12e7992063d32656295c6e293da3080    1
e8377dadc2aa45661dcb4c42001929f5f468ee38e6e43def72a01dcb0380a4cf7d3cda5763c05df5ab38e340027b1be6b4267d1a0927b953da0fdc1eefb1d89c    1
033d3739a24f770cc8330563cf933880d354714f95fc734e2473a0e6e473dccfcb3722bfde881217891cc25a6e58466766e62dd9b427017282f550c26e8f15f5    1
c758e981f4c9207ef3c56d794d5cd1b0fdf24a188eee70198b2f7a13c79ca67ed092f7e4dffb7797a0a2485241536776a7c2c22fdfdf6b2fbb3e91381da638cd    1
Name: count, Length: 9883, dtype: int64
In [21]:
duplicated.sum()
Out[21]:
10170

It seems, out of the 33k files, 10k files are duplicated!

This could be a major source of leakage, contributing to getting a false sense of performance, as well as a drain of resources. We will not keep the duplicated files.

Technical note: The chance of random collisions using sha512 is extremely rare ($1 \over 10^{77}$), so we're not even going to bother checking them manually.
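For reference, that figure corresponds to sha512's generic (birthday-bound) collision-resistance level, $2^{256} \approx 10^{77}$; for our concrete ~33k files the birthday bound gives a probability that is smaller still:

```latex
p_{\text{collision}} \;\lesssim\; \frac{n^2}{2^{\,b+1}}
  \;=\; \frac{33488^2}{2^{513}} \;\approx\; 2^{-483}
  \qquad (n = 33488 \text{ files},\; b = 512 \text{ bits})
```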

In [22]:
dir_clean_1 = "dataset/clean1_removed_identical_hash/"
os.makedirs(dir_clean_1, exist_ok=True)
In [23]:
hashes["new_filename"] = (
    hashes["filename"]
    .str.replace("dataset/", dir_clean_1)
    .str.replace("UTKFace/", "")
    .str.replace("crop_part1/", "")
)
In [24]:
files_unique_hashes = hashes.groupby("filehash").first()
files_unique_hashes.head()
Out[24]:
filename new_filename
filehash
000c10431583581973efb97a4ab0f8b08850e72d797988b8b8efd44ac78c5c1668ace33dbe7bc87ddff54dc1911034905ad226dea45bf076dfcab08ba8a44162 dataset/UTKFace/17_1_1_20170114030034621.jpg.c... dataset/clean1_removed_identical_hash/17_1_1_2...
000edaff14d1bd3d804af056b37a159f0154453164d9ef0f8414c097292affb8134a809fbd6cb0e3990a026d58e746e873e0bd7451f51e076ef78d292a82b1b8 dataset/UTKFace/35_0_1_20170117121610224.jpg.c... dataset/clean1_removed_identical_hash/35_0_1_2...
001117a1aeb92e6b8d34e19490472b77a4b5751365bed9d43fa0e16d73ef2e28b43760d33a7524e24cc6e5edbfe0a327b2f8169697e6501982e52a45e6e8601d dataset/crop_part1/42_0_0_20170104183950934.jp... dataset/clean1_removed_identical_hash/42_0_0_2...
0018ef8ae5a1e95d1447e0c4a36e1de0362923503490913085840060d6e90ff67c4e51e2a741a0409fa0ae63ea64c73cf561c11d464392d3b12df8a6462eb99a dataset/UTKFace/31_0_0_20170117181923333.jpg.c... dataset/clean1_removed_identical_hash/31_0_0_2...
001e33bc5fa66b3caaae5bda44c41de6b7d91aaf66c2b782b0ef11809ed35aadae3fc85582efa15935f404beb49c189c0eb15d63fc782e30a6b92501dca5a3ef dataset/UTKFace/58_0_1_20170113174947234.jpg.c... dataset/clean1_removed_identical_hash/58_0_1_2...
In [25]:
if run_entire_notebook("excluding_files_identical_hash"):
    pbar = tqdm(files_unique_hashes.iterrows(), desc="removing files with identical hash")

    for index, row in pbar:
        shutil.copy(row.filename, row.new_filename)
skipping optional operation
==== 🗃️ printing cached output ====
selecting unique files: 23318it [00:01, 19289.47it/s]

We're done!

The dataset/clean1_removed_identical_hash/ directory contains each unique picture, without duplicates.

In [26]:
files_unique_hashes
Out[26]:
filename new_filename
filehash
000c10431583581973efb97a4ab0f8b08850e72d797988b8b8efd44ac78c5c1668ace33dbe7bc87ddff54dc1911034905ad226dea45bf076dfcab08ba8a44162 dataset/UTKFace/17_1_1_20170114030034621.jpg.c... dataset/clean1_removed_identical_hash/17_1_1_2...
000edaff14d1bd3d804af056b37a159f0154453164d9ef0f8414c097292affb8134a809fbd6cb0e3990a026d58e746e873e0bd7451f51e076ef78d292a82b1b8 dataset/UTKFace/35_0_1_20170117121610224.jpg.c... dataset/clean1_removed_identical_hash/35_0_1_2...
001117a1aeb92e6b8d34e19490472b77a4b5751365bed9d43fa0e16d73ef2e28b43760d33a7524e24cc6e5edbfe0a327b2f8169697e6501982e52a45e6e8601d dataset/crop_part1/42_0_0_20170104183950934.jp... dataset/clean1_removed_identical_hash/42_0_0_2...
0018ef8ae5a1e95d1447e0c4a36e1de0362923503490913085840060d6e90ff67c4e51e2a741a0409fa0ae63ea64c73cf561c11d464392d3b12df8a6462eb99a dataset/UTKFace/31_0_0_20170117181923333.jpg.c... dataset/clean1_removed_identical_hash/31_0_0_2...
001e33bc5fa66b3caaae5bda44c41de6b7d91aaf66c2b782b0ef11809ed35aadae3fc85582efa15935f404beb49c189c0eb15d63fc782e30a6b92501dca5a3ef dataset/UTKFace/58_0_1_20170113174947234.jpg.c... dataset/clean1_removed_identical_hash/58_0_1_2...
... ... ...
fff42adb969d15c96001a1e5bb2f5cfd71ee30c6fadad5a77d71d3ae74e5e7ee9b449bb388c69439cc9a78bf6a7c429a49e6f8b676a1979d865ff6d0046aba64 dataset/crop_part1/61_0_3_20170109141653583.jp... dataset/clean1_removed_identical_hash/61_0_3_2...
fff4dd576d231e871559d5326323537d89d642d683d54d3b42a9b4e3cf8526d75e175cc2423903e2f951a7d4fdd1ee7b7eae9816c0da92948044e6b5c8b03ed9 dataset/crop_part1/54_0_0_20170104213004356.jp... dataset/clean1_removed_identical_hash/54_0_0_2...
fff818ed2fc6d24e6d1560a33104df80f031f0508ae7ce161ff0f7ce98670ff99f8cad821db7a981a8eff56700e6b43dd52def6df2c873d870067a89823238f0 dataset/crop_part1/61_1_0_20170110122324992.jp... dataset/clean1_removed_identical_hash/61_1_0_2...
fff8f71b41b9dd31352b39ae16d6811791b7bf810f10326d527c6aa0806caeabbc710277a0bd5c4d95b60eccd36893dd3b06b07dc487b570df2df0c23096b902 dataset/UTKFace/20_0_0_20170117140842001.jpg.c... dataset/clean1_removed_identical_hash/20_0_0_2...
fff9e8f3cde1ebb72b715362774665d98e20a6345bcb46576cbdb1af86413a89b8cbb812d7fc33454e9ef0039d5af757f3a08e16c86eae345774ae64b7049405 dataset/crop_part1/54_1_1_20170110120122138.jp... dataset/clean1_removed_identical_hash/54_1_1_2...

23318 rows × 2 columns

Removing extremely similar images (using phash)¶

Now that we have deleted duplicated (identical) files, the easy part of the work is done.

But we suspect that there might be "similar but not identical" images.

These will be harder to detect because they look similar to a human but are not byte-identical. We will try a few algorithms and keep the one that proves most effective for this part of the cleaning.

Performance considerations and Time complexity analysis¶

Check out the sandbox folder, which contains a few notebooks where we tried and tested several algorithms to detect and identify similar images:

  • from cosine similarity,
  • to using externally available projects: similar-images-remover, and modifying them to improve their performance using multiprocessing
  • to trying out various methods to perform dimensionality reduction (PCA, ...)
  • to using libraries to hash images based on appearance
  • and others...

In this notebook, we just kept the system that proved best (in terms of results and performance). Some of the algorithms were suboptimal and had a time complexity of $O(n^2)$ which was not ideal when you have 20k images: 20k * 20k = 400 million comparisons.

We had to find other approaches that bring the search below quadratic time (linear, ideally).

In [27]:
image_similarity_analysis = (
    pd.DataFrame(files_unique_hashes["new_filename"])
    .reset_index(drop=True)
    .rename(columns={"new_filename": "filename"})
)
image_similarity_analysis
Out[27]:
filename
0 dataset/clean1_removed_identical_hash/17_1_1_2...
1 dataset/clean1_removed_identical_hash/35_0_1_2...
2 dataset/clean1_removed_identical_hash/42_0_0_2...
3 dataset/clean1_removed_identical_hash/31_0_0_2...
4 dataset/clean1_removed_identical_hash/58_0_1_2...
... ...
23313 dataset/clean1_removed_identical_hash/61_0_3_2...
23314 dataset/clean1_removed_identical_hash/54_0_0_2...
23315 dataset/clean1_removed_identical_hash/61_1_0_2...
23316 dataset/clean1_removed_identical_hash/20_0_0_2...
23317 dataset/clean1_removed_identical_hash/54_1_1_2...

23318 rows × 1 columns

In [28]:
def hash_with_length(length_bytes=8):
    def hash_imagefile(filename: str) -> imagehash.ImageHash:
        """
        Calculates a perceptual hash for a file so that "similar images" can be
        compared via a brightness score for each cluster of pixels.
        The cluster grid is configured using length_bytes: a higher value gives
        a grid with more cells (n^2) and slightly longer comparisons, but no
        real benefit for this dataset.
        """
        return imagehash.phash(Image.open(filename), hash_size=length_bytes)

    return hash_imagefile


def hex_to_int(hash_value: str) -> int:
    return int(str(hash_value), base=16)
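For intuition: since the hashes are stored as hex strings, the Hamming distance we compute later is equivalent to XOR-ing the two integer values and counting the set bits. A minimal sketch (`hamming_hex` is our own helper, not part of imagehash):

```python
def hamming_hex(h1: str, h2: str) -> int:
    """Hamming distance between two equal-length hex hash strings:
    the popcount of the XOR of their integer values."""
    return bin(int(h1, 16) ^ int(h2, 16)).count("1")


hamming_hex("95f5aa5681c2b43e", "95f5aa5681c2b43f")  # → 1 (only the last bit differs)
```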
In [29]:
@cached_dataframe()
def similarity_analysis_with_hash():
    image_similarity_analysis["hash_str_8"] = image_similarity_analysis["filename"].apply(
        hash_with_length(8)
    )
    image_similarity_analysis["hash_str_8"] = image_similarity_analysis[
        "hash_str_8"
    ].astype("str")
    return image_similarity_analysis


similarity_analysis_with_hash()
Loading from cache [./cached/df/similarity_analysis_with_hash.parquet]
Out[29]:
filename hash_str_8
0 dataset/clean1_removed_identical_hash/17_1_1_2... 95f5aa5681c2b43e
1 dataset/clean1_removed_identical_hash/35_0_1_2... f8dfc300c7c07d32
2 dataset/clean1_removed_identical_hash/42_0_0_2... 91c54cd3c7c3326b
3 dataset/clean1_removed_identical_hash/31_0_0_2... 95854a96e3db622d
4 dataset/clean1_removed_identical_hash/58_0_1_2... 9dc51e9580d70b76
... ... ...
23313 dataset/clean1_removed_identical_hash/61_0_3_2... 95da8549d28b726d
23314 dataset/clean1_removed_identical_hash/54_0_0_2... 91ad5ed593a26835
23315 dataset/clean1_removed_identical_hash/61_1_0_2... d0944996c79325df
23316 dataset/clean1_removed_identical_hash/20_0_0_2... c6c119c796d7313c
23317 dataset/clean1_removed_identical_hash/54_1_1_2... dcd05b8863e24d3b

23318 rows × 2 columns

Let's see if there are similar-ish pictures (pictures that have an identical Perceptual Hash).

In [30]:
@run
@cached_chart()
def similar_percept_hash():
    similar = similarity_analysis_with_hash()
    counts = similar["hash_str_8"].value_counts()
    counts = counts[counts > 1]
    sns.countplot(x=counts[counts > 1], order=counts.value_counts().index)
    return plt.gcf()
Loading from cache [./cached/charts/similar_percept_hash.png]

Removing Somewhat similar images (phash with thresholding)¶

  • Observations

    • We can see that over 200 of them have identical matches for the hash.
    • We can also see that some have more than 2 matches.
    • This does not even include pictures that have "very similar" matches, just identical ones.
  • Outcome

    • We will try to find "similar-enough" images as well. Image pairs that do not have identical hashes, but that are close enough.
    • This will require extensive computations since we have to effectively compare all images to all other ones. O(N^2) time complexity.
    • Other methods (KNN, etc..), might look like they can automate this for us, but they will also suffer similar performance penalties.
      • KNN does not scale well to $20k * 20k = 400,000,000$ comparisons.
      • Even worse if we consider each point is represented in a $8*8 = 64$ dimensions hyperplane
      • KNN and other options are discarded. Likely to perform worse, and still overkill.
    • We will try to use some optimizations to speed up this brute force method:
      • numba-just-in-time
      • parallelization
      • look-ahead-only comparisons (comparing each unordered pair once, instead of all pics against all): ${23318 \choose 2}$
    • Some other approaches seemed overkill.

Overall, the code below can compute all the comparisons in less than 45 seconds on a cheap laptop CPU, and in about 20 seconds on a desktop CPU.

Let's remember how Combinations work:

${C_k(n)} = {n\choose k} = {{n!} \over {k!(n-k)!}}$

  • for $n = 23318$
  • and $k = 2$

${C_2(23318)} = {23318\choose 2} = {{23318!} \over {2!(23318-2)!}} = 271852903$

  • Total number of calcs = 271,852,903 instead of 543,729,124.
  • Which gives us a ratio of 0.4999785, i.e. ~2x faster (as expected).
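We can sanity-check these numbers with the standard library:

```python
import math

n = 23318
pairs = math.comb(n, 2)   # look-ahead-only: each unordered pair counted once
all_vs_all = n * n        # naive all-against-all double loop

print(pairs)               # 271852903
print(all_vs_all)          # 543729124
print(pairs / all_vs_all)  # ≈ 0.4999785
```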

Based on what I can see in here https://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html I think we can treat each bit as an individual piece of info and calculate hamming distance by just comparing each bit individually! :)

We will use @numba just-in-time (JIT) decorators to try to speed up the expensive computations

In [31]:
@numba.jit(nopython=True)
def hamming_distance(hash1, hash2):
    return np.sum(hash1 != hash2)


@numba.jit(nopython=True, parallel=True)
def calculate_distances(i, hashes, threshold):
    is_duplicate = np.zeros(len(hashes), dtype=np.bool_)
    for j in numba.prange(i + 1, len(hashes)):
        hamming_dist = hamming_distance(hashes[i], hashes[j])
        if hamming_dist <= threshold:
            is_duplicate[j] = True
    return np.where(is_duplicate)[0]


def mark_duplicates(df, hash_col, threshold):
    """
    lists image pairs that are detected as duplicate

    """
    hashes = np.array(
        [list(map(int, bin(int(str(h), 16))[2:].zfill(64))) for h in df[hash_col]]
    )

    duplicates_dict = {}
    for i in tqdm(range(len(df))):
        duplicate_indices = calculate_distances(i, hashes, threshold)
        if len(duplicate_indices) > 0:
            duplicates_dict[i] = duplicate_indices

    return duplicates_dict
  • Observations:
    • threshold of 1 was a good proof of concept; let's try more tolerance
    • threshold of 2 was still too narrow
    • threshold of 12 was too loose: similar shadows/lighting, but very different pics
    • threshold of 6 showed pictures of people in similar poses, but clearly not the same pic
    • threshold of 3 showed good matches; let's try 4
    • threshold of 4 showed 2x as many matches as threshold of 3
    • threshold of 5 was still good: 800 sets of images (1 key + 1 or more near-identical images)
    • threshold of 6 found 2x as many as 5 (1800 sets of images), and was back to pretty bad
      • We found some wildly wrong matches; we noted down some particular ids to make sure they still appear at lower thresholds
  • Conclusion:
    • for this dataset, a threshold of 5 seems ideal
In [32]:
# Just to remember the performance of this algorithm for future benchmarks
# 23k images = 543 million comparisons = 19 seconds!
# Numba and JIT compilation rocks!
@cached_with_pickle(force=run_entire_notebook("mark_duplicates_phash"))
def image_similarity_adj_matrix():
    similar = similarity_analysis_with_hash()
    duplicates_dict = mark_duplicates(similar, hash_col="hash_str_8", threshold=5)
    return duplicates_dict


image_similarity_adj = image_similarity_adj_matrix()
skipping optional operation
==== 🗃️ printing cached output ====
100%|██████████| 23318/23318 [00:19<00:00, 1167.02it/s]

Loading from cache [./cached/pickle/image_similarity_adj_matrix.pickle]
  • Observations:
    • We see that most of the similar images have different tags: same image, different target age, etc. Since it's impossible for us to determine which labels are correct, we will consider all flagged images to be compromised, and we will drop them from our dataset.
    • despite the summary counts showing ~700 images duplicated, remember that this needs further postprocessing (due to how we configured our code for speed: using only-look-forward search).
  • Outcome:
    • we will consider "duplicated" pics if their IDs appear in the keys, or in the values.
    • we will drop all "duplicated" pics from our dataset
In [33]:
# just a sample, to visualize what this adjacency matrix looks like:
for idx in list(image_similarity_adj.keys())[80:110]:
    print(idx, image_similarity_adj[idx])
1169 [18638]
1198 [5356]
1226 [16114]
1231 [12054 17146]
1232 [13587 18548]
1236 [16791]
1247 [14691]
1248 [21699]
1254 [15256]
1256 [3180]
1261 [15627]
1281 [15290]
1312 [ 6256 20831]
1323 [7851]
1331 [13056 18001 18889]
1337 [2977]
1342 [7509 9209]
1356 [ 5219 11216]
1376 [11845]
1385 [13283]
1393 [9222]
1394 [7352]
1403 [2285 2777]
1410 [7755]
1438 [4272]
1447 [20833]
1448 [22896]
1469 [ 9073 19880 21486]
1470 [16086]
1509 [14763 20468]
In [34]:
@cached_dataframe(force=run_entire_notebook())
def duplicates_similarity_df():
    similar = similarity_analysis_with_hash()
    duplicates_dict = image_similarity_adj_matrix()

    all_keys = list(duplicates_dict.keys())
    all_values = np.concatenate(list(duplicates_dict.values())).tolist()
    all_duplicated_pics_ids = set(all_keys + all_values)

    similar["is_similar"] = False
    similar.loc[list(all_duplicated_pics_ids), "is_similar"] = True
    return similar


image_with_similarity = duplicates_similarity_df()
image_with_similarity
skipping optional operation
Loading from cache [./cached/df/duplicates_similarity_df.parquet]
Out[34]:
filename hash_str_8 is_similar
0 dataset/clean1_removed_identical_hash/17_1_1_2... 95f5aa5681c2b43e True
1 dataset/clean1_removed_identical_hash/35_0_1_2... f8dfc300c7c07d32 False
2 dataset/clean1_removed_identical_hash/42_0_0_2... 91c54cd3c7c3326b False
3 dataset/clean1_removed_identical_hash/31_0_0_2... 95854a96e3db622d False
4 dataset/clean1_removed_identical_hash/58_0_1_2... 9dc51e9580d70b76 False
... ... ... ...
23313 dataset/clean1_removed_identical_hash/61_0_3_2... 95da8549d28b726d False
23314 dataset/clean1_removed_identical_hash/54_0_0_2... 91ad5ed593a26835 False
23315 dataset/clean1_removed_identical_hash/61_1_0_2... d0944996c79325df True
23316 dataset/clean1_removed_identical_hash/20_0_0_2... c6c119c796d7313c True
23317 dataset/clean1_removed_identical_hash/54_1_1_2... dcd05b8863e24d3b False

23318 rows × 3 columns

In [35]:
image_with_similarity.is_similar.value_counts()
Out[35]:
is_similar
False    21848
True      1470
Name: count, dtype: int64
In [36]:
def plot_pics(ids: list[int]):
    f, ax = plt.subplots(1, len(ids), figsize=(len(ids) * 5, 5))
    for i in range(len(ids)):
        filename = str(similar.loc[ids[i]]["filename"])
        print(filename)
        ax[i].imshow(Image.open(filename))
    plt.tight_layout()
    return plt.gcf()

Just for future reference, these are some of the images used to determine the threshold to use to detect "similar enough" images.

We have extracted the IDs of these images while doing manual tuning of the thresholds.

Observations:

  • You will notice that a higher threshold results in too many false positives.
  • At threshold 5 and below there are no false positives (for the results we inspected manually)
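The "similar enough" decision boils down to a Hamming-distance threshold between the 64-bit perceptual hashes (the `hash_str_8` column). A minimal sketch of that comparison, assuming the hashes are 16-character hex strings and using the illustrative threshold of 5 tuned below:

```python
# Minimal sketch: Hamming distance between two 64-bit perceptual hashes
# stored as 16-character hex strings (as in the hash_str_8 column).
def phash_distance(hash_a: str, hash_b: str) -> int:
    diff = int(hash_a, 16) ^ int(hash_b, 16)
    return bin(diff).count("1")  # number of differing bits


def is_similar(hash_a: str, hash_b: str, threshold: int = 5) -> bool:
    # threshold=5 is the value settled on by the manual inspection;
    # higher values produced false positives
    return phash_distance(hash_a, hash_b) <= threshold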
In [37]:
# Threshold of 12
@run
@cached_chart()
def similar_threshold_12_a():
    print(
        """
dataset/clean1_removed_identical_hash/15_1_4_20170103230530985.jpg.chip.jpg
dataset/clean1_removed_identical_hash/37_1_0_20170109134008515.jpg.chip.jpg
dataset/clean1_removed_identical_hash/23_1_0_20170116221811019.jpg.chip.jpg
dataset/clean1_removed_identical_hash/30_1_1_20170116012131745.jpg.chip.jpg
dataset/clean1_removed_identical_hash/17_1_0_20170109214021426.jpg.chip.jpg"""
    )
    return plot_pics([93, 12244, 17302, 17699, 19950])
Loading from cache [./cached/charts/similar_threshold_12_a.png]

There are some similarities, but they are clearly different people

In [38]:
# Threshold of 6
@run
@cached_chart()
def similar_threshold_6_a():
    print(
        """
dataset/clean1_removed_identical_hash/32_1_0_20170117154910644.jpg.chip.jpg
dataset/clean1_removed_identical_hash/22_1_3_20170119153416689.jpg.chip.jpg
dataset/clean1_removed_identical_hash/32_1_0_20170117134809503.jpg.chip.jpg"""
    )
    return plot_pics([7834, 8970, 19856])
Loading from cache [./cached/charts/similar_threshold_6_a.png]
In [39]:
# Threshold of 6 - nope!
@run
@cached_chart()
def similar_threshold_6_b():
    print(
        """
dataset/clean1_removed_identical_hash/16_1_0_20170109213504335.jpg.chip.jpg
dataset/clean1_removed_identical_hash/27_1_3_20170104223505487.jpg.chip.jpg
dataset/clean1_removed_identical_hash/24_1_2_20170104234618170.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170117154940189.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152030871.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170117174028333.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170104235421282.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_2_20170104021040316.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152019467.jpg.chip.jpg
"""
    )
    return plot_pics([5122, 7395, 11127, 11444, 15119, 16152, 18056, 19818, 22147])
Loading from cache [./cached/charts/similar_threshold_6_b.png]
In [40]:
# Threshold of 5
# Good! the 2 clear matches from the previous round still appear at threshold 5!
# nice!
@run
@cached_chart()
def similar_threshold_5_a():
    print(
        """
dataset/clean1_removed_identical_hash/25_1_3_20170117152030871.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152019467.jpg.chip.jpg
"""
    )
    return plot_pics([15119, 22147])
Loading from cache [./cached/charts/similar_threshold_5_a.png]
In [41]:
# Threshold of 5
@run
@cached_chart()
def similar_threshold_5_b():
    print(
        """
dataset/clean1_removed_identical_hash/72_0_0_20170111201853033.jpg.chip.jpg
dataset/clean1_removed_identical_hash/65_0_0_20170120225159632.jpg.chip.jpg
dataset/clean1_removed_identical_hash/75_0_0_20170111205238382.jpg.chip.jpg
"""
    )
    return plot_pics([1044, 5567, 21425])
Loading from cache [./cached/charts/similar_threshold_5_b.png]
In [42]:
# Threshold of 5
@run
@cached_chart()
def similar_threshold_5_c():
    print(
        """
dataset/clean1_removed_identical_hash/30_0_1_20170113141654362.jpg.chip.jpg
dataset/clean1_removed_identical_hash/28_0_1_20170103225933161.jpg.chip.jpg
dataset/clean1_removed_identical_hash/32_0_1_20170113001102379.jpg.chip.jpg
"""
    )
    return plot_pics([15758, 16371, 19337])
Loading from cache [./cached/charts/similar_threshold_5_c.png]
In [43]:
# Threshold of 4
@run
@cached_chart()
def similar_threshold_4_a():
    print(
        """
dataset/clean1_removed_identical_hash/23_1_2_20170116173016687.jpg.chip.jpg
dataset/clean1_removed_identical_hash/23_1_2_20170116173145383.jpg.chip.jpg
dataset/clean1_removed_identical_hash/24_0_2_20170116164749805.jpg.chip.jpg
"""
    )
    return plot_pics([49, 515, 3572])
Loading from cache [./cached/charts/similar_threshold_4_a.png]

Threshold of 3

Looking good!

In [44]:
# threshold of 3
@run
@cached_chart()
def similar_threshold_3_a():
    print(
        """
dataset/clean1_removed_identical_hash/1_0_0_20170110213328641.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219200139603.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219204741557.jpg.chip.jpg
"""
    )
    return plot_pics([985, 15770, 20737])
Loading from cache [./cached/charts/similar_threshold_3_a.png]
In [45]:
# threshold of 3
@run
@cached_chart()
def similar_threshold_3_b():
    print(
        """
dataset/clean1_removed_identical_hash/1_0_0_20170110213328641.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219200139603.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219204741557.jpg.chip.jpg
"""
    )
    return plot_pics([11415, 11752, 19133])
Loading from cache [./cached/charts/similar_threshold_3_b.png]

One last safety check to make sure that, despite being marked as similar, these images do in fact have different file hashes

In [46]:
@run
@cached_chart()
def similar_threshold_3_f():
    print(
        """
dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg
"""
    )
    return plot_pics([17241, 18228, 21854])
Loading from cache [./cached/charts/similar_threshold_3_f.png]
In [47]:
!sha256sum dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
!sha256sum dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
!sha256sum dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg
1e5fad3db6fe0f7f172d9f5f358ea91f5a3dc7f6904f5b0761c591580a75d412  dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
01761854e5833bc963def5bfe0bcb8cc69050c608a6bf0582ecda5a946397674  dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
e88ff0ceca238030b763cdc585be1e9e119eb22028be59ab3fb835c243db48de  dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg

Observations:

  • we can see that for these 3 images, even though they appear very similar (and the phash confirms it), the file hashes are different, which is why these sets were not detected during the initial file-hash scan, nor during the removal of "extremely similar images"
In [48]:
@run
@cached_chart()
def similar_images_perceptual_hash():
    return sns.countplot(similar, x="is_similar")
Loading from cache [./cached/charts/similar_images_perceptual_hash.png]

With this new technique, we have removed an additional 5% of the images, which were clearly similar/identical and would have caused some form of data leakage had any of the duplicates ended up in our test split.

In [49]:
image_with_similarity
Out[49]:
filename hash_str_8 is_similar
0 dataset/clean1_removed_identical_hash/17_1_1_2... 95f5aa5681c2b43e True
1 dataset/clean1_removed_identical_hash/35_0_1_2... f8dfc300c7c07d32 False
2 dataset/clean1_removed_identical_hash/42_0_0_2... 91c54cd3c7c3326b False
3 dataset/clean1_removed_identical_hash/31_0_0_2... 95854a96e3db622d False
4 dataset/clean1_removed_identical_hash/58_0_1_2... 9dc51e9580d70b76 False
... ... ... ...
23313 dataset/clean1_removed_identical_hash/61_0_3_2... 95da8549d28b726d False
23314 dataset/clean1_removed_identical_hash/54_0_0_2... 91ad5ed593a26835 False
23315 dataset/clean1_removed_identical_hash/61_1_0_2... d0944996c79325df True
23316 dataset/clean1_removed_identical_hash/20_0_0_2... c6c119c796d7313c True
23317 dataset/clean1_removed_identical_hash/54_1_1_2... dcd05b8863e24d3b False

23318 rows × 3 columns

In [50]:
image_with_similarity["is_similar"].value_counts()
Out[50]:
is_similar
False    21848
True      1470
Name: count, dtype: int64
In [51]:
without_similar_images = image_with_similarity[~image_with_similarity["is_similar"]]
without_similar_images
Out[51]:
filename hash_str_8 is_similar
1 dataset/clean1_removed_identical_hash/35_0_1_2... f8dfc300c7c07d32 False
2 dataset/clean1_removed_identical_hash/42_0_0_2... 91c54cd3c7c3326b False
3 dataset/clean1_removed_identical_hash/31_0_0_2... 95854a96e3db622d False
4 dataset/clean1_removed_identical_hash/58_0_1_2... 9dc51e9580d70b76 False
5 dataset/clean1_removed_identical_hash/9_0_0_20... 95207e37e017691f False
... ... ... ...
23311 dataset/clean1_removed_identical_hash/1_0_2_20... 818c1e27de536cd9 False
23312 dataset/clean1_removed_identical_hash/42_0_0_2... b904d69792d7286d False
23313 dataset/clean1_removed_identical_hash/61_0_3_2... 95da8549d28b726d False
23314 dataset/clean1_removed_identical_hash/54_0_0_2... 91ad5ed593a26835 False
23317 dataset/clean1_removed_identical_hash/54_1_1_2... dcd05b8863e24d3b False

21848 rows × 3 columns

Let's check folder size before deleting the files:

In [52]:
!ls dataset/clean1_removed_identical_hash/ | wc -l
23318
In [53]:
dir_clean_2 = "dataset/clean2_removed_similar_images/"
In [54]:
if run_entire_notebook():
    source_dir = dir_clean_1
    target_dir = dir_clean_2
    os.makedirs(target_dir, exist_ok=True)
    for dissimilar_filename in without_similar_images["filename"]:
        source = dissimilar_filename
        target = source.replace(source_dir, target_dir)
        shutil.copy(source, target)
skipping optional operation
In [55]:
def does_file_actually_exist(filename) -> bool:
    return os.path.exists(filename)


file_exists = pd.DataFrame(without_similar_images["filename"])
file_exists["exists"] = file_exists["filename"].map(does_file_actually_exist)
file_exists.head()
Out[55]:
filename exists
1 dataset/clean1_removed_identical_hash/35_0_1_2... True
2 dataset/clean1_removed_identical_hash/42_0_0_2... True
3 dataset/clean1_removed_identical_hash/31_0_0_2... True
4 dataset/clean1_removed_identical_hash/58_0_1_2... True
5 dataset/clean1_removed_identical_hash/9_0_0_20... True
In [56]:
file_exists.exists.value_counts()
Out[56]:
exists
True    21848
Name: count, dtype: int64
In [57]:
!ls dataset/clean1_removed_identical_hash/ | wc -l
23318
In [58]:
!ls dataset/clean2_removed_similar_images/ | wc -l
21848
  • Observations:
    • we went from 23318 files to 21848 files (1470 files skipped, as expected ✅)
In [59]:
without_similar_images.head()
Out[59]:
filename hash_str_8 is_similar
1 dataset/clean1_removed_identical_hash/35_0_1_2... f8dfc300c7c07d32 False
2 dataset/clean1_removed_identical_hash/42_0_0_2... 91c54cd3c7c3326b False
3 dataset/clean1_removed_identical_hash/31_0_0_2... 95854a96e3db622d False
4 dataset/clean1_removed_identical_hash/58_0_1_2... 9dc51e9580d70b76 False
5 dataset/clean1_removed_identical_hash/9_0_0_20... 95207e37e017691f False
In [60]:
without_similar_images = without_similar_images.copy()
without_similar_images["filename"] = without_similar_images.copy()[
    "filename"
].str.replace(dir_clean_1, dir_clean_2)

Splitting the dataset¶

We want our dataset split into non-overlapping subsets so that we can test how our model performs on unseen data.
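The 70/15/15 split performed below can be sketched with a plain shuffle over row indices (a minimal stdlib sketch; the notebook's own `split_utils.split_dataset` does the equivalent over the labeled dataframe):

```python
import random


# Minimal sketch of a 70/15/15 train/val/test split over row indices.
# The test split takes the remainder, so no rows are silently dropped.
def split_indices(n: int, seed: int = 42) -> dict[str, list[int]]:
    rng = random.Random(seed)  # fixed seed keeps the split repeatable
    idx = list(range(n))
    rng.shuffle(idx)
    n_train = int(n * 0.70)
    n_val = int(n * 0.15)
    return {
        "train": idx[:n_train],
        "val": idx[n_train : n_train + n_val],
        "test": idx[n_train + n_val :],
    }
```

Persisting the resulting splits to disk (as done below) is what guarantees consistency across kernel restarts.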

In [61]:
def parse_labels_from_filename(dir_name_prefix: str = dir_clean_2):
    """
    derive the age/gender/ethnicity labels from the
    underscore-separated tokens encoded in the filename
    """

    def _enrich(row):
        filename: str = row["filename"]
        filename = filename.replace(dir_name_prefix, "")
        tokens = filename.split("_")
        if len(tokens) != 4:
            logger.warning("unable to parse filename %s", filename)
            return None

        row["age"] = int(tokens[0])
        row["gender"] = gender_map[int(tokens[1])]
        row["ethnicity"] = ethnicity_map[int(tokens[2])]
        return row

    return _enrich
In [62]:
def tag_all_pics(df: pd.DataFrame) -> pd.DataFrame:
    """
    tags the dataframe using the filename for age, gender and ethnicity
    """
    pics = pd.DataFrame(df["filename"]).reset_index(drop=True)
    pics.columns = ["filename"]
    pics = pics.apply(parse_labels_from_filename(dir_clean_2), axis=1).dropna()
    pics["age"] = pics["age"].astype(int)
    return pics


@cached_dataframe()
def all_pics_tagged():
    return tag_all_pics(without_similar_images)
In [63]:
pics = all_pics_tagged()
pics.head(10)
Loading from cache [./cached/df/all_pics_tagged.parquet]
Out[63]:
filename age gender ethnicity
0 dataset/clean2_removed_similar_images/35_0_1_2... 35 male black
1 dataset/clean2_removed_similar_images/42_0_0_2... 42 male white
2 dataset/clean2_removed_similar_images/31_0_0_2... 31 male white
3 dataset/clean2_removed_similar_images/58_0_1_2... 58 male black
4 dataset/clean2_removed_similar_images/9_0_0_20... 9 male white
5 dataset/clean2_removed_similar_images/37_0_0_2... 37 male white
6 dataset/clean2_removed_similar_images/20_1_0_2... 20 female white
7 dataset/clean2_removed_similar_images/38_1_1_2... 38 female black
8 dataset/clean2_removed_similar_images/28_1_3_2... 28 female indian
9 dataset/clean2_removed_similar_images/26_1_2_2... 26 female asian
  • Observations
    • We lost 3 images because their filenames do not follow the standard syntax (each is missing 1 of the labels)
  • Outcomes:
    • We drop those files; they won't be included in any split

Note how the first/default mapping (0, 0) corresponds to ("white", "male").

The project only requires classifying based on gender and age, but we will also keep ethnicity for some parts of the analysis, as it might have an impact.
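The label parsing above relies on the `[age]_[gender]_[ethnicity]_[timestamp]` filename convention of the UTKFace dataset. A minimal sketch of the same token parsing, with toy stand-ins for the `gender_map` / `ethnicity_map` dicts defined earlier in the notebook:

```python
# Toy versions of the mapping dicts defined earlier in the notebook;
# note how (0, 0) corresponds to ("male", "white"), the default mapping.
toy_gender_map = {0: "male", 1: "female"}
toy_ethnicity_map = {0: "white", 1: "black", 2: "asian", 3: "indian", 4: "other"}


def parse_utkface(filename: str):
    tokens = filename.split("_")
    if len(tokens) != 4:  # malformed names (like the 3 dropped files) are skipped
        return None
    age, gender, ethnicity, _timestamp = tokens
    return int(age), toy_gender_map[int(gender)], toy_ethnicity_map[int(ethnicity)]


# matches row 0 of the tagged dataframe: 35, male, black
parse_utkface("35_0_1_20170117154131789.jpg.chip.jpg")
```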

In [64]:
@cached_dataframes()
def pics_splits():
    splits = split_utils.split_dataset(
        pics,
        target_cols=["age", "gender"],
        stratify_labels=False,
        split_sizes={"train": 0.7, "val": 0.15, "test": 0.15},
    )
    return {
        "train_X": splits["train"][0],
        "train_y": splits["train"][1],
        "val_X": splits["val"][0],
        "val_y": splits["val"][1],
        "test_X": splits["test"][0],
        "test_y": splits["test"][1],
    }
In [65]:
splits = pics_splits()
splits.keys()
Loading from cache [./cached/df_dict/pics_splits.h5]
Out[65]:
dict_keys(['test_X', 'test_y', 'train_X', 'train_y', 'val_X', 'val_y'])

We will keep the ethnicity for future bias analysis, but it will not be fed to the model during training. This column is just here for convenience:

In [66]:
splits["train_X"].head()
Out[66]:
filename ethnicity
1594 dataset/clean2_removed_similar_images/54_0_0_2... white
2202 dataset/clean2_removed_similar_images/14_1_0_2... white
4384 dataset/clean2_removed_similar_images/23_0_0_2... white
701 dataset/clean2_removed_similar_images/17_0_0_2... white
16274 dataset/clean2_removed_similar_images/32_1_0_2... white
In [67]:
splits["train_y"].head()
Out[67]:
age gender
1594 54 male
2202 14 female
4384 23 male
701 17 male
16274 32 female

Storing splits separately¶

In [68]:
os.makedirs("dataset/splits", exist_ok=True)
os.makedirs("dataset/splits/train", exist_ok=True)
os.makedirs("dataset/splits/val", exist_ok=True)
os.makedirs("dataset/splits/test", exist_ok=True)
In [69]:
if run_entire_notebook("splitting_into_folders"):
    splits = pics_splits()
    for dataset in ["train", "val", "test"]:
        target_folder = f"dataset/splits/{dataset}/"
        files = pd.DataFrame(splits[f"{dataset}_X"]["filename"])
        files["target"] = files["filename"].str.replace(dir_clean_2, target_folder)
        print("*" * 20, dataset, "*" * 20)
        display(files.head())

        for index, row in files.iterrows():
            shutil.copy(row.filename, row.target)
skipping optional operation
==== 🗃️ printing cached output ====
	filename 	target
1594 	dataset/clean2_removed_similar_images/54_0_0_2... 	dataset/splits/train/54_0_0_20170117190252594....
2202 	dataset/clean2_removed_similar_images/14_1_0_2... 	dataset/splits/train/14_1_0_20170109203638205....
4384 	dataset/clean2_removed_similar_images/23_0_0_2... 	dataset/splits/train/23_0_0_20170114034609023....
701 	dataset/clean2_removed_similar_images/17_0_0_2... 	dataset/splits/train/17_0_0_20170105183607439....
16274 	dataset/clean2_removed_similar_images/32_1_0_2... 	dataset/splits/train/32_1_0_20170103182408417....

******************** val ********************

	filename 	target
17631 	dataset/clean2_removed_similar_images/61_0_0_2... 	dataset/splits/val/61_0_0_20170117174613406.jp...
14808 	dataset/clean2_removed_similar_images/22_1_1_2... 	dataset/splits/val/22_1_1_20170114033301951.jp...
2040 	dataset/clean2_removed_similar_images/26_1_1_2... 	dataset/splits/val/26_1_1_20170116222929223.jp...
16488 	dataset/clean2_removed_similar_images/61_0_0_2... 	dataset/splits/val/61_0_0_20170111222237144.jp...
8732 	dataset/clean2_removed_similar_images/58_0_0_2... 	dataset/splits/val/58_0_0_20170113142246036.jp...

******************** test ********************

	filename 	target
15975 	dataset/clean2_removed_similar_images/27_1_0_2... 	dataset/splits/test/27_1_0_20170117120616194.j...
4956 	dataset/clean2_removed_similar_images/32_0_0_2... 	dataset/splits/test/32_0_0_20170117140353209.j...
11260 	dataset/clean2_removed_similar_images/68_1_0_2... 	dataset/splits/test/68_1_0_20170113210319664.j...
8461 	dataset/clean2_removed_similar_images/42_0_0_2... 	dataset/splits/test/42_0_0_20170109012239137.j...
11413 	dataset/clean2_removed_similar_images/26_0_3_2... 	dataset/splits/test/26_0_3_20170104230323233.j...

Checking integrity of copy operation¶

In [70]:
def verify_copy_integrity(splitname: str):
    """
    checks that the number of files in the folder split
    matches the number of files in the dataframe,
    to make sure that the copy was correct and
    no files were lost or accidentally included
    """
    files_in_dir = !ls -l dataset/splits/$splitname/ | wc --lines
    files_in_dir = int(files_in_dir[0]) - 1  # gotta skip the header line
    files_in_df_split = pics_splits()[f"{splitname}_X"].shape[0]

    print(files_in_df_split, " == ", files_in_dir)
    assert files_in_df_split == files_in_dir
    util.check(files_in_df_split == files_in_dir)
In [71]:
verify_copy_integrity("train")
Loading from cache [./cached/df_dict/pics_splits.h5]
15291  ==  15291
✅
In [72]:
verify_copy_integrity("val")
Loading from cache [./cached/df_dict/pics_splits.h5]
3277  ==  3277
✅
In [73]:
verify_copy_integrity("test")
Loading from cache [./cached/df_dict/pics_splits.h5]
3277  ==  3277
✅
  • Observations:
    • splits into dataset/splits/* was correct for all 3 splits of data.
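As a side note, the file count above could also be done without shelling out to `ls | wc` (a portable sketch; the temp directory is only for demonstration):

```python
import os
import tempfile


# Portable alternative to `ls -l ... | wc --lines`: count directory
# entries with os.listdir, which has no header-line off-by-one to skip.
def count_files(path: str) -> int:
    return len(os.listdir(path))


# demo on a throwaway directory containing 3 empty files
with tempfile.TemporaryDirectory() as tmp:
    for name in ("a.jpg", "b.jpg", "c.jpg"):
        open(os.path.join(tmp, name), "w").close()
    n = count_files(tmp)
```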
In [74]:
def comparison_across_datasplits(col_name: str):
    tr = splits["train_X"].join(splits["train_y"])
    v = splits["val_X"].join(splits["val_y"])
    tst = splits["test_X"].join(splits["test_y"])

    f, ax = plt.subplots(1, 3, figsize=(15, 8))
    if col_name == "age":
        sns.histplot(data=tr, binwidth=5, y=col_name, ax=ax[0], color=moonstone)
        sns.histplot(data=v, binwidth=5, y=col_name, ax=ax[1], color=moonstone)
        sns.histplot(data=tst, binwidth=5, y=col_name, ax=ax[2], color=moonstone)
    else:
        order = gender_map.values() if col_name == "gender" else ethnicity_map.values()
        print(order)
        sns.countplot(data=tr, y=col_name, ax=ax[0], color=moonstone, order=order)
        sns.countplot(data=v, y=col_name, ax=ax[1], color=moonstone, order=order)
        sns.countplot(data=tst, y=col_name, ax=ax[2], color=moonstone, order=order)
    plt.suptitle(f"{col_name} distribution across data splits")
    ax[0].set_title("train split")
    ax[1].set_title("val split")
    ax[2].set_title("test split")
    plt.tight_layout()
    return f
In [75]:
@run
@cached_chart()
def split_age_comparison():
    return comparison_across_datasplits("age")
Loading from cache [./cached/charts/split_age_comparison.png]
In [76]:
@run
@cached_chart()
def split_gender_comparison():
    return comparison_across_datasplits("gender")
Loading from cache [./cached/charts/split_gender_comparison.png]
In [77]:
@run
@cached_chart()
def split_ethnicity_comparison():
    return comparison_across_datasplits("ethnicity")
Loading from cache [./cached/charts/split_ethnicity_comparison.png]

Despite not having used hard stratification, the splits seem close enough to each other to be representative across the labels we care about (age, gender).

Exploratory Data Analysis¶

Now that we have ensured:

  • that the duplicates are detected and removed
  • that the data is split properly and in a repeatable manner (stored to disk, to ensure consistency across kernel/system reboots)
  • and that the splits are representative of each other

We're ready to take a look at the data. We will use the train split for EDA.

From now on, we will use the pre-split datasets/pictures from inside the dataset/splits/{split_name} folders

In [78]:
def load_dataset(split_name: str) -> pd.DataFrame:
    path = f"dataset/splits/{split_name}"
    df = pd.DataFrame(list_all_files(path), columns=["filename"])
    return df.apply(parse_labels_from_filename(path + "/"), axis=1).dropna()
In [79]:
@cached_dataframe()
def cached_train_df():
    return load_dataset("train")


@cached_dataframe()
def cached_val_df():
    return load_dataset("val")


@cached_dataframe()
def cached_test_df():
    return load_dataset("test")


df_train = cached_train_df()
df_val = cached_val_df()
df_test = cached_test_df()
Loading from cache [./cached/df/cached_train_df.parquet]
Loading from cache [./cached/df/cached_val_df.parquet]
Loading from cache [./cached/df/cached_test_df.parquet]
In [80]:
overview = df_train
overview.head()
Out[80]:
filename age gender ethnicity
0 dataset/splits/train/26_1_4_20170117154131789.... 26 female other
1 dataset/splits/train/2_1_4_20161221203029673.j... 2 female other
2 dataset/splits/train/30_1_0_20170109001620649.... 30 female white
3 dataset/splits/train/25_0_0_20170120221436173.... 25 male white
4 dataset/splits/train/10_0_4_20170103202338152.... 10 male other
In [81]:
@run
@cached_chart()
def dataset_labels():
    f, ax = plt.subplots(1, 4, figsize=(15, 4))

    sns.histplot(data=overview, binwidth=1, x="age", ax=ax[0], color=moonstone)
    sns.histplot(data=overview, binwidth=10, x="age", ax=ax[1], color=moonstone)
    sns.histplot(data=overview, x="gender", ax=ax[2], color=moonstone)
    sns.histplot(data=overview, x="ethnicity", ax=ax[3], color=moonstone)

    ax[0].set_title("age distribution (1yr group)")
    ax[1].set_title("age distribution (decade)")
    ax[2].set_title("gender distribution")
    ax[3].set_title("ethnicity distribution")

    plt.tight_layout()
    return f
Loading from cache [./cached/charts/dataset_labels.png]
In [82]:
@run
@cached_chart()
def population_breakdown():
    ethnicities = overview["ethnicity"].unique()

    fig, axs = plt.subplots(2, len(ethnicities), figsize=(3 * len(ethnicities), 13))

    for i, ethnicity in enumerate(ethnicity_map.values()):
        data = overview[overview["ethnicity"] == ethnicity]
        male_df = data[data["gender"] == "male"]
        female_df = data[data["gender"] == "female"]

        sns.histplot(
            data=male_df,
            binwidth=5,
            y="age",
            ax=axs[0, i],
            color="grey",
            multiple="dodge",
            label="male",
        )
        sns.histplot(
            data=female_df,
            binwidth=5,
            y="age",
            ax=axs[0, i],
            color="lightgrey",
            multiple="dodge",
            label="female",
        )

        sns.histplot(
            data=male_df,
            binwidth=1,
            y="age",
            ax=axs[1, i],
            color="grey",
            multiple="dodge",
            label="male",
        )
        sns.histplot(
            data=female_df,
            binwidth=1,
            y="age",
            ax=axs[1, i],
            color="lightgrey",
            multiple="dodge",
            label="female",
        )

        axs[0, i].legend()
        axs[1, i].legend()

        axs[0, i].set_title(f"{ethnicity} - 5 yr buckets")
        axs[1, i].set_title(f"{ethnicity}")
        axs[0, i].set_xlim(0, 550)
        axs[1, i].set_xlim(0, 275)

    plt.tight_layout()
    return fig
Loading from cache [./cached/charts/population_breakdown.png]

Even though the previous "overall aggregates" seemed to present a balanced dataset, when we break it down by gender/age it's clear that each slice has its own peculiarities.

A few things that jump out:

  • The overrepresentation of white people
  • The overrepresentation of male pics in most slices, except in the "asian" slice, where genders are roughly balanced at 50/50
    • This might make classification difficult despite "male/female" looking balanced in the aggregates.
  • Some particular classes are extremely overrepresented:
    • The very large number of < 5 yr, asian, male samples (compared to other baby groups). Might they be mislabeled?
    • The large spike at 26 year olds (in both the male and female splits). This raises questions about how this data was sampled/collected.
    • There are almost no babies, children or teenagers (between 2 yrs and 20 yrs), except for ethnicity=white

Browsing our dataset¶

Let's build some utility methods to slice and browse our dataset, using any of the 3 dimensions we have: age, gender, ethnicity

In [83]:
def population_filter(
    population: pd.DataFrame,
    ages: list[int] = None,
    genders: list[str] = None,
    ethnicities: list[str] = None,
) -> pd.DataFrame:
    """
    Retrieves a few samples from the population, which match the criteria specified.
    """
    data = population.copy()

    if ages:
        if not isinstance(ages, list):
            ages = [ages]
        data = data[data["age"].isin(ages)]

    if genders:
        if not isinstance(genders, list):
            genders = [genders]
        data = data[data["gender"].isin(genders)]

    if ethnicities:
        if not isinstance(ethnicities, list):
            ethnicities = [ethnicities]
        data = data[data["ethnicity"].isin(ethnicities)]

    return data

Now we can easily inspect slices of the dataframe and see how few samples we have for any given slice of data.

In [84]:
population_filter(overview, ages=[13, 11], genders="female", ethnicities="indian")
Out[84]:
filename age gender ethnicity
875 dataset/splits/train/13_1_3_20170109213029072.... 13 female indian
7066 dataset/splits/train/11_1_3_20170104223632543.... 11 female indian
13577 dataset/splits/train/13_1_3_20170117181350659.... 13 female indian
In [85]:
def population_sample(
    population: pd.DataFrame,
    ages: list[int] = None,
    genders: list[str] = None,
    ethnicities: list[str] = None,
) -> plt.Figure:
    """
    renders pictures from the population that match the specified criteria (based on the pic labels)
    """
    sample = population_filter(
        population, ages=ages, genders=genders, ethnicities=ethnicities
    )
    sample_pics = sample.sample(min(100, len(sample)))["filename"]

    cols = 10
    rows = int(np.ceil(len(sample_pics) / cols))

    f, ax = plt.subplots(rows, cols, figsize=(20, (2.8 * rows)))

    for row, col in itertools.product(np.arange(rows), np.arange(cols)):
        if (row * cols + col) < len(sample_pics):
            filename = sample_pics.iloc[row * cols + col]
            image = Image.open(filename)
            ax[row, col].imshow(image)
            title = filename.split("/")[-1:][0]
            title_labels = title[:6]
            ts = title[15:-13]
            ax[row, col].set_title(title_labels + "..." + ts)
        ax[row, col].axis("off")

    plt.suptitle(f"{genders = }, {ages = }, {ethnicities = }")
    plt.tight_layout()
    plt.show()
    return f

We can also easily visualize pics for any arbitrary subset:

In [86]:
@run
@cached_chart(extension="jpg")
def pop_sample_black_male():
    return population_sample(overview, ages=[50, 51], genders="male", ethnicities="black")
Loading from cache [./cached/charts/pop_sample_black_male.jpg]
In [87]:
@run
@cached_chart(extension="jpg")
def pop_sample_white_female():
    return population_sample(
        overview, ages=[50, 51, 52], genders="female", ethnicities="white"
    )
Loading from cache [./cached/charts/pop_sample_white_female.jpg]

⚠️ For the avid reviewer: you will notice that some pics might be mislabeled.

For example, Gene Wilder's picture appears labeled as "white, woman, 50 years of age"... but that specific picture is a frame from the movie "Willy Wonka & the Chocolate Factory".

  • Age: Gene Wilder (born 1933) was 38 years old when he played the role of Willy Wonka in the 1971 movie "Willy Wonka & the Chocolate Factory".
  • Gender: Gene Wilder was assigned male at birth and I cannot find any reference to him coming out as a trans woman at any point in his career.
  • Gender Expression of the Character: The character that Gene is playing (Willy Wonka) is also portrayed and identified as a male in all the literature I can find, using he/him pronouns to refer to him throughout all reference material found online.
  • Age references of the character: The character that Gene is playing (Willy Wonka) is said to be 47 years old in the movie.

Even if we assumed that the person depicted is not Gene Wilder, but his character, both the gender and the age are still incorrect.

The fact that this image is misclassified on both dimensions raises concerns around the quality of the data this dataset contains.

It is well known that Willy Wonka is not a 50 year old woman, but an immortal meme, in this day and age.

In [88]:
@run
@cached_chart(extension="jpg")
def pop_sample_other_male_10_13():
    return population_sample(
        overview, ages=[10, 11, 12, 13, 14], genders="male", ethnicities="other"
    )
Loading from cache [./cached/charts/pop_sample_other_male_10_13.jpg]
  • Observations:
    • We can see that some slices of the dataset are particularly sparse. This is ALL (100% of the samples) we have across 5 entire years for age=young, gender=male, ethnicity="other". The model had better be really good at learning, because it will get no help beyond this. It is unlikely that any data augmentation technique could significantly increase the training datapoints for this demographic.
  • Conclusions:
    • We will need to closely inspect the model's performance around these (and other) underrepresented demographics.

Classifying images using computer vision¶

In Module 4 - Sprint 1, we used PyTorch Lightning to classify pictures of mushrooms. It was a challenge similar to this one, but there were some things that were not ideal.

Framework of choice - FastAI¶

Despite PyTorch Lightning being more lightweight than plain PyTorch, it was still quite verbose and boilerplate-heavy. I did not enjoy the experience... and I'm someone who really appreciates having knobs and dials to tune and configure everything...

Reducing boilerplate to allow us to focus on high-level architecture customization¶

For this project, I'd like to try out and get to know a different library/framework. Ideally one that is more lightweight and allows us to easily and quickly try different structures for our NN.

FastAI seems like an ideal candidate. I have heard it is more concise and has a cleaner wrapper API that tucks away a lot of the boilerplate code/steps. I'm looking forward to this!

In terms of the approach, I expect this to be quite similar to the previous project (find a pretrained model, customize it to have the needed output layer,... ) but the trick will be to select a loss function that allows us to train that last FC layer.

Model Architecture¶

The requirements call for building and training a single model that can predict both targets in a single pass.

I consider this to be a poor idea. Normally I'd prefer to have 2 simpler models that can be composed/combined and more importantly, trained separately. But I understand that this is a good requirement to make sure we can practice in-depth customization of models.

Some other approaches that could be done:

  1. The "Keep it simple"
    • We could train 2 simpler models: one for age, one for gender.
    • Does not meet the "single pass" requirement.
    • ❌ We will completely skip this
  2. The "Iteration 1": the model we have chosen as the first iteration for this project.
    • ✅ Train one model with 2 output layers, each trained individually:
    • Iteration 1.1: The same, but have age done as classification (buckets of 5 years/10 years) instead of regression.
    • Iteration 1.2: The same, but have age using buckets of different ranges based on business needs (eg. "underage": <= 18, "young": 19 to 30, adult: 31 to 50, etc...). If we had more context or needs from the business we could better tune this to find the best way to achieve ideal classification performance.
  3. The "Iteration 2": Training both parts independently (in 2 different models) and then merging them for predictions (might be overkill and might not improve performance)
    • This would require making sure that both models use the exact same pretrained model as backbone.
    • The tricky part of this architecture is how we specify the loss function. We want our model to learn the difference between getting the gender wrong (incorrect classification), getting the age slightly wrong (off by a tiny amount), and totally messing up the age (off by a large amount). I suspect that training 2 independent models as output FC layers on top of a pretrained image model would allow us to find the best tuning for each of the 2 tasks, and then we could merge those two output layers into the final encapsulation.
    • ❌ Due to time constraints this was not completed and was removed from this project.

Model Structure¶

An architecture that could be constructed is this:

Most of the low-level operations (flatten, etc...) are handled by, and managed via the framework.

Loss Function¶

We will likely want to penalize different errors in different ways.

I thought of using an orthogonal vectorial loss function, so that the loss would be a 2-dimensional vector. This would allow tuning the error of each of the 2 outputs independently, but I was discouraged from doing so after talking to a few STLs from the course. The underlying logic was that, given enough training, a single loss that aggregates both individual losses using a simple addition (+) would also work.

I decided to start with that as a starting point (and to listen to the old YAGNI adage)

In [89]:
from fastai.callback.core import Callback
In [90]:
from torch.utils.tensorboard import SummaryWriter
2024-02-29 09:30:28.505076: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-29 09:30:28.996098: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-29 09:30:28.996125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-29 09:30:29.087464: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-29 09:30:29.251982: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-29 09:30:30.073575: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
In [91]:
from fastai.vision.all import (
    Path,
    TransformBlock,
    DataBlock,
    ImageBlock,
    FuncSplitter,
    get_image_files,
    Resize,
    aug_transforms,
    Module,
    resnet34,
    resnet50,
    resnet101,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    create_body,
    create_head,
    Learner,
    FloatTensor,
    load_learner,
)
In [92]:
pretrained_configs = {
    "small": {
        "model": resnet34,
        "weights": ResNet34_Weights.IMAGENET1K_V1,
        "num_features": 512,
    },
    "medium": {
        "model": resnet50,
        "weights": ResNet50_Weights.IMAGENET1K_V1,
        "num_features": 2048,
    },
    "large": {
        "model": resnet101,
        "weights": ResNet101_Weights.IMAGENET1K_V1,
        "num_features": 2048,
    },
}

model_size = "small"
pretrained_config = pretrained_configs[model_size]
In [93]:
path = Path("dataset/splits")
In [94]:
def get_labels_age_gender_ethnicity(fname):
    labels = fname.name.split("_")
    # we need float so that age
    # calculations are correct!
    age = float(labels[0])
    gender = int(labels[1])
    ethnicity = int(labels[2])
    return age, gender, ethnicity


def get_labels_age_gender(fname):
    labels = fname.name.split("_")
    age = float(labels[0])
    gender = int(labels[1])
    return age, gender
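For reference, these labels come straight from the UTKFace filename convention, `{age}_{gender}_{ethnicity}_{timestamp}.jpg`. A quick check on a made-up filename:

```python
from pathlib import Path

# UTKFace filenames encode the labels directly as
# {age}_{gender}_{ethnicity}_{timestamp}.jpg.
# This sample filename is made up for illustration.
sample = Path("25_0_2_20170116174525125.jpg")

labels = sample.name.split("_")
age, gender = float(labels[0]), int(labels[1])
print(age, gender)  # 25.0 0
```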
In [95]:
class AgeGenderBlock(TransformBlock):
    def __init__(
        self,
        type_tfms=None,
        item_tfms=None,
        batch_tfms=None,
        dl_type=None,
        dls_kwargs=None,
    ):
        return super().__init__(
            type_tfms=type_tfms,
            item_tfms=item_tfms,
            batch_tfms=batch_tfms,
            dl_type=dl_type,
            dls_kwargs=dls_kwargs,
        )
In [96]:
def dataset_splitter(filename):
    return filename.parent.name in ["val"]


dblock = DataBlock(
    blocks=(ImageBlock, AgeGenderBlock()),
    splitter=FuncSplitter(dataset_splitter),
    get_items=get_image_files,
    get_y=get_labels_age_gender,
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224, min_scale=0.75),
)
In [97]:
age_losses = []
gender_losses = []
total_losses = []
In [98]:
def weighted_loss(
    age_normalization_factor: float,
    gender_normalization_factor: float,
    age_to_gender_weight_coef: float,
):
    """
    age_normalization_factor: how to normalize the loss for age
    gender_normalization_factor: how to normalize the loss for gender
    age_to_gender_weight_coef: how to distribute the weight for the losses, normally 0.5
    """

    def shared_loss(pred, targ) -> float:
        age_pred, gender_pred = pred
        age_targ, gender_targ = targ

        # mse seems to be too sensitive to outliers
        # age_loss = F.huber_loss(age_pred.squeeze(), age_targ.float())
        age_loss = F.mse_loss(age_pred.squeeze(), age_targ.float())
        gender_loss = F.cross_entropy(gender_pred, gender_targ.long())

        age_loss = age_loss * age_normalization_factor
        gender_loss = gender_loss * gender_normalization_factor

        # standard formula could be, B×MSE+(1−B)×BCE so we only need 1 param!
        w = age_to_gender_weight_coef
        loss = (w * age_loss) + (1 - w) * gender_loss

        # note: stash the individual losses on the (global) learner so the
        # AgeGenderLossLogger callback can log them separately
        learn.age_loss = age_loss.item()
        learn.gender_loss = gender_loss.item()
        return loss

    return shared_loss
In [99]:
class AgeGenderLossLogger(Callback):
    def __init__(self, writer):
        self.writer = writer
        self.train_iter = 0
        self.valid_iter = 0

    def after_loss(self):
        learn = self.learn
        add_scalar = self.writer.add_scalar
        if self.training:
            i = self.train_iter
            add_scalar("Loss/age", learn.age_loss, i)
            add_scalar("Loss/gender", learn.gender_loss, i)
            add_scalar("Loss/total", learn.loss, i)
            self.train_iter += 1
        else:
            i = self.valid_iter
            add_scalar("Valid_Loss/age", learn.age_loss, i)
            add_scalar("Valid_Loss/gender", learn.gender_loss, i)
            add_scalar("Valid_Loss/total", learn.loss, i)
            self.valid_iter += 1

    def after_fit(self):
        self.writer.close()
In [100]:
dls = dblock.dataloaders(path, num_workers=4)
In [101]:
class AgeGenderModel(Module):
    def __init__(self, encoder, n_age_classes, n_gender_classes):
        self.encoder = create_body(encoder(weights=pretrained_config["weights"]))
        self.age_head = create_head(pretrained_config["num_features"], n_age_classes)
        self.gender_head = create_head(
            pretrained_config["num_features"], n_gender_classes
        )

    def forward(self, x):
        x = self.encoder(x)
        age = self.age_head(x)
        gender = self.gender_head(x)
        return age, gender
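The two-head design can be sanity-checked with a toy stand-in for the encoder. This is only a sketch: the real notebook uses `create_body(resnet34(...))` and fastai's `create_head`; here a single conv layer keeps things light, and the head sizes mirror the real ones (1 regression output for age, 2 logits for gender):

```python
import torch
from torch import nn

# Toy stand-in for the pretrained backbone; sizes are illustrative only.
class ToyAgeGenderModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.age_head = nn.Linear(8, 1)     # regression: 1 output
        self.gender_head = nn.Linear(8, 2)  # classification: 2 logits

    def forward(self, x):
        x = self.encoder(x)
        return self.age_head(x), self.gender_head(x)

model = ToyAgeGenderModel()
age, gender = model(torch.randn(4, 3, 224, 224))
print(age.shape, gender.shape)  # torch.Size([4, 1]) torch.Size([4, 2])
```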

It seems that the unweighted age loss is much larger than the gender loss (obviously!)

  • age_loss = tensor(1830.1670 ...
  • gender_loss = tensor(0.6796 ...
2024-02-06 17:41:25,186 - root - INFO - w = 0.5
2024-02-06 17:41:25,251 - root - INFO - age_loss = tensor(1830.1670, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:41:25,252 - root - INFO - gender_loss = tensor(0.6796, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:41:25,253 - root - INFO - loss = tensor(915.4233, device='cuda:0', grad_fn=<AddBackward0>)
2024-02-06 17:41:25,464 - root - INFO - w = 0.5
2024-02-06 17:41:25,528 - root - INFO - age_loss = tensor(1113.5537, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:41:25,529 - root - INFO - gender_loss = tensor(0.4847, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:41:25,529 - root - INFO - loss = tensor(557.0192, device='cuda:0', grad_fn=<AddBackward0>)

Let's try to find a weight that brings the weighted age loss to the same order of magnitude as the gender loss.

It seems that using a weight of $1 \over 2000$ we achieve a good balance:

  • (w * age_loss) = tensor(0.6909 ...
  • (1 - w) * gender_loss = tensor(0.6667, ...
2024-02-06 17:43:59,496 - root - INFO - w = 0.0005
2024-02-06 17:43:59,563 - root - INFO - age_loss = tensor(1381.7511, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:43:59,563 - root - INFO - gender_loss = tensor(0.6670, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:43:59,564 - root - INFO - (w * age_loss) = tensor(0.6909, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,565 - root - INFO - (1 - w) * gender_loss = tensor(0.6667, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,565 - root - INFO - loss = tensor(1.3576, device='cuda:0', grad_fn=<AddBackward0>)
2024-02-06 17:43:59,779 - root - INFO - w = 0.0005
2024-02-06 17:43:59,841 - root - INFO - age_loss = tensor(1527.9042, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:43:59,842 - root - INFO - gender_loss = tensor(0.5323, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:43:59,843 - root - INFO - (w * age_loss) = tensor(0.7640, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,843 - root - INFO - (1 - w) * gender_loss = tensor(0.5320, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,844 - root - INFO - loss = tensor(1.2960, device='cuda:0', grad_fn=<AddBackward0>)
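The balancing weight can also be derived directly from the raw losses: solving $w \cdot \text{age\_loss} = (1 - w) \cdot \text{gender\_loss}$ gives $w = \frac{\text{gender\_loss}}{\text{age\_loss} + \text{gender\_loss}}$. A quick sanity check with the values logged above:

```python
# Raw (unweighted) losses taken from the log lines above
age_loss = 1381.7511
gender_loss = 0.6670

# Solve w * age_loss = (1 - w) * gender_loss for w
w = gender_loss / (age_loss + gender_loss)
print(round(w, 4))  # 0.0005, i.e. roughly 1/2000

# Both weighted terms now sit on the same order of magnitude
print(round(w * age_loss, 3), round((1 - w) * gender_loss, 3))  # 0.667 0.667
```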
In [102]:
def new_learner():
    writer = SummaryWriter()
    model = AgeGenderModel(
        pretrained_config["model"], n_age_classes=1, n_gender_classes=2
    )

    learn = Learner(
        dls,
        model,
        loss_func=weighted_loss(
            age_normalization_factor=1 / 10,
            gender_normalization_factor=1 / 0.8,
            age_to_gender_weight_coef=1 / 2,
        ),
        cbs=[AgeGenderLossLogger(writer)],
    )
    return learn
In [103]:
learn = new_learner()
In [363]:
if run_entire_notebook("training_model"):
    epochs = 30
    learn = learn.to_fp16()
    learn.fit_one_cycle(epochs, lr_max=0.005)
    model = learn.model
    timestamp_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    exportfilename = f"models/learner_{timestamp_str}_{model_size}_{epochs}epochs.pkl"
    last_learner = f"models/learner_chosen_good.pkl"
    learn.export(exportfilename, pickle_module=dill)
    learn.export(
        last_learner, pickle_module=dill
    )  # last trained model overwrite for convenience
else:
    print("\n-----\nloading model from disk")
    learn = load_learner("models/learner_chosen_good.pkl", pickle_module=dill, cpu=False)
    print("done ✅")
skipping optional operation
==== 🗃️ printing cached output ====
epoch 	train_loss 	valid_loss 	time
0 	1.786554 	1.695351 	00:55
1 	1.461248 	1.307861 	00:53
2 	0.647388 	0.517804 	00:54
3 	0.563197 	0.525605 	00:54
4 	0.546874 	0.516621 	00:54
5 	0.515017 	0.508244 	00:54
6 	0.519758 	0.462976 	00:54
7 	0.505805 	0.607207 	00:54
8 	0.486558 	0.532631 	00:55
9 	0.579306 	0.621597 	00:56
10 	0.477896 	0.498944 	00:56
11 	0.532490 	0.478929 	00:56
12 	0.463207 	0.431897 	00:54
13 	0.464652 	0.457235 	00:56
14 	0.452019 	0.424648 	00:55
15 	0.426468 	0.458743 	00:57
16 	0.425189 	0.399749 	00:56
17 	0.422736 	0.448454 	00:55
18 	0.407325 	0.400459 	00:55
19 	0.394068 	0.391915 	00:54
20 	0.391277 	0.412320 	00:53
21 	0.392562 	0.392210 	00:53
22 	0.397992 	0.381142 	00:53
23 	0.375167 	0.405484 	00:53
24 	0.367188 	0.378330 	00:53
25 	0.372667 	0.350337 	00:53
26 	0.376210 	0.370057 	00:53
27 	0.359483 	0.369580 	00:53
28 	0.362384 	0.348189 	00:53
29 	0.346523 	0.348245 	00:53
30 	0.338588 	0.361895 	00:53


-----
loading model from disk
done ✅

Training performance using mixed precision training¶

Something important to note: just by enabling mixed precision training using .to_fp16(), we have been able to consistently speed up training by an impressive factor:

  • Without Mixed Precision Training: 2m 12s per epoch
  • With Mixed Precision Training: 1m 22s per epoch

This represents a speedup of $132 / 82 \approx 1.6$, i.e. roughly 60% faster!

All of this, almost for free, without sacrificing meaningful precision: mixed precision training uses smaller data types (fp16) where it can, but falls back to full single precision (fp32) where numerical accuracy requires it.

Model file size¶

Let's check the size to see how massive it is:

In [105]:
!ls -hs models/learner_chosen_good.pkl
87M models/learner_chosen_good.pkl

Resource usage during training¶

While our model is training, let's check our system resources, settings and config to see how much we're utilizing the GPU:

(py39_lab4) edu@desk:~$ nvidia-smi 
Tue Feb  6 18:11:46 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3060        Off | 00000000:01:00.0  On |                  N/A |
| 46%   67C    P2             134W / 170W |  11957MiB / 12288MiB |     96%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1278      G   /usr/lib/xorg/Xorg                          277MiB |
|    0   N/A  N/A      1887      G   /opt/teamviewer/tv_bin/TeamViewer             2MiB |
|    0   N/A  N/A      1930      G   cinnamon                                     50MiB |
|    0   N/A  N/A     27807      G   ...GI0Y2VkNTI2Zg%3D%3D&browser=firefox       35MiB |
|    0   N/A  N/A     50469      G   /usr/lib/firefox/firefox                    162MiB |
|    0   N/A  N/A     52900      G   /usr/lib/firefox/firefox                      7MiB |
|    0   N/A  N/A     54109      C   ...anaconda3/envs/py39_lab4/bin/python    11390MiB |
+---------------------------------------------------------------------------------------+
(py39_lab4) edu@desk:~$ 

We can see that the configuration of the model/workers/batch is able to utilize almost the entirety of the GPU memory (11.1GB out of 12GB).

We did not have to configure anything manually for this to work so well out of the box!

FastAI is so much nicer and more enjoyable than using PyTorch Lightning and having to tweak and configure low-level settings just to avoid crashing our GPU.

It's nice seeing a framework that takes care of the low-level details with sane and reasonable defaults. My understanding is that, if we ever wanted to, we should be able to drop into lower-level code (plain PyTorch) to fully tweak minute settings (ideally similar to how seaborn offers a high-level API, but still allows low-level plt code to tweak minutiae).

We haven't encountered a single occasion that required us to do that 🎉 (so far)

Comparing Training speed¶

Comparing training speed (1 epoch), depending on the pretrained model size:

| Size   | Model     | Time per Epoch | Comments                                                                      |
|--------|-----------|----------------|-------------------------------------------------------------------------------|
| Small  | Resnet34  | 1m21           | without optimizations: 2m18                                                   |
| Medium | Resnet50  | 1m56           |                                                                               |
| Large  | Resnet101 | 2m54           | just a bit over the "small" if we configure it without mixed precision optimizations! |
In [106]:
@cached_with_pickle(force=run_entire_notebook())
def learn_losses_train():
    return learn.recorder.losses


@cached_with_pickle(force=run_entire_notebook())
def learn_losses_validation():
    return learn.recorder.values
skipping optional operation
skipping optional operation
In [107]:
train_losses = learn_losses_train()
valid_losses = learn_losses_validation()
Loading from cache [./cached/pickle/learn_losses_train.pickle]
Loading from cache [./cached/pickle/learn_losses_validation.pickle]

Losses are not stored as part of the model, so we will store them separately in case we need them later
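The @cached_with_pickle helper used above is defined elsewhere in the project; a minimal sketch of how such a pickle-backed cache might look (names, paths and the `force` behaviour are illustrative assumptions):

```python
import pickle
from functools import wraps
from pathlib import Path

# Hypothetical sketch of a @cached_with_pickle-style decorator: caches a
# function's return value on disk, keyed by the function name, and loads
# from disk on subsequent calls unless force=True.
def cached_with_pickle(force=False, cache_dir="./cached/pickle"):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            path = Path(cache_dir) / f"{fn.__name__}.pickle"
            if path.exists() and not force:
                print(f"Loading from cache [{path}]")
                return pickle.loads(path.read_bytes())
            result = fn(*args, **kwargs)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_bytes(pickle.dumps(result))
            return result
        return wrapper
    return decorator
```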

Assessing model training¶

In [108]:
@run
@cached_chart()
def chart_train_loss():
    plot = sns.lineplot([float(loss) for loss in train_losses], label="train loss")
    plt.title("Training Loss")
    plt.xlabel("batch")
    plt.ylabel("loss")
    plt.ylim(0)
    plt.xlim(0)
    return plot
Loading from cache [./cached/charts/chart_train_loss.png]
In [109]:
@run
@cached_chart()
def chart_val_loss():
    plot = sns.lineplot([float(loss[0]) for loss in valid_losses], label="val loss")
    plt.title("Validation loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.ylim(0)
    plt.xlim(0)
    return plot
Loading from cache [./cached/charts/chart_val_loss.png]

It seems that the 30-epoch config is enough to get a good balance of learning/performance without wasting resources (the chart hints that we're reaching marginal gains towards the last epochs, with its asymptotic behaviour)

Inspecting the model architecture¶

In [110]:
learn.model
Out[110]:
AgeGenderModel(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (4): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (5): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (age_head): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): fastai.layers.Flatten(full=False)
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=False)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): Linear(in_features=512, out_features=1, bias=False)
  )
  (gender_head): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): fastai.layers.Flatten(full=False)
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=False)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): Linear(in_features=512, out_features=2, bias=False)
  )
)

Visualizing the model's learning progress through tensorboard¶

(py39_lab4) edu@desk:~/turing/projects/sprint15-profiling/project$ tensorboard --logdir runs/
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6006/ (Press CTRL+C to quit)

Tensorboard has been a very useful tool to inspect and debug the model's learning (it also helped detect a bug in my code that resulted in data leakage)

  • Observations
    • Top picture is from some of the rounds with overfitting
    • Bottom picture is from the last round of training with fewer epochs. The flat part at the end of the line is the prediction rounds; they are logged but the loss is not updated because the model is not training. It can be ignored and was only there for inspection/debug purposes

Assessing Performance¶

Assessing performance when predicting age¶

Context:

  • There are numerous studies assessing how humans perform when asked to guess someone's age (from passport photos).

The researchers found that overall, the volunteers were incorrect by an average of eight years in their estimates.

  • sources
    1. article "Study shows how bias can influence people estimating the ages of other people" - https://medicalxpress.com/news/2018-10-bias-people-ages.html
    2. study: "Age estimation via face images: a survey" - https://jivp-eurasipjournals.springeropen.com/articles/10.1186/s13640-018-0278-6

Outcome:

  • We will add some "acceptable error" margins ($\pm 8$ years) to allow the model the same leniency as the participants in these studies.

Assessing performance when predicting gender¶

Observation:

  • This dataset treats gender as a binary label.

Outcome:

  • For now, we will simplify our assessment of model performance as a simple binary classification problem.
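Putting both rules together, the correctness checks applied below can be sketched on toy data (the numbers are made up): an age prediction counts as correct within $\pm 8$ years, while gender must match exactly.

```python
# Toy predictions (made up) illustrating the assessment rules:
# age is correct within +/-8 years, gender must match exactly.
tolerance_years = 8

actual_ages = [25, 40, 8, 60]
pred_ages   = [30, 55, 9, 53]
actual_genders = [0, 1, 1, 0]
pred_genders   = [0, 1, 0, 0]

age_correct = [abs(p - a) <= tolerance_years
               for p, a in zip(pred_ages, actual_ages)]
gender_correct = [p == a for p, a in zip(pred_genders, actual_genders)]

print(age_correct)     # [True, False, True, True]
print(gender_correct)  # [True, True, False, True]
```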
In [112]:
def performance_for_split(splitname: str):
    """
    assess inference performance for the given data split.

    Performs inference and collect a summary dataframe for each of
        the inputs, with performance metrics

    :param splitname: name of the subdir under dataset/splits/{splitname}
    """
    tolerance_years = 8
    items = get_image_files(path / splitname)
    print(len(items))
    actuals = [get_labels_age_gender_ethnicity(item) for item in items]
    dl = dls.test_dl(items)
    preds, _ = learn.get_preds(dl=dl)
    display(preds)
    performance = pd.DataFrame(
        {
            "actual_age": [t[0] for t in actuals],
            "pred_age": preds[0][:, 0],
            "actual_gender": [t[1] for t in actuals],
            "pred_gender": torch.argmax(preds[1], dim=1),
            "filename": [str(p) for p in items],
            "ethnicity": [t[2] for t in actuals],
        }
    )
    performance["age_error"] = performance.pred_age - performance.actual_age
    performance["gender_error"] = performance.pred_gender - performance.actual_gender

    performance["age_correct"] = performance["age_error"].between(
        -tolerance_years, tolerance_years
    )
    performance["gender_correct"] = performance["gender_error"] == 0

    performance.attrs["splitname"] = splitname
    return performance
In [113]:
def plot_age_errors(performance_df: pd.DataFrame, ax=None):
    if ax is None:
        ax = plt.gca()

    title = f"age errors - {performance_df.attrs['splitname']} split"
    ax.set_title(title)
    # ax.set_xlim(-40, 40)
    return sns.histplot(
        data=performance_df,
        x="age_error",
        hue="age_correct",
        palette=["red", "green"],
        # bins=np.arange(-40, 40, 1),
        ax=ax,
    )


def plot_age_errors_color(
    performance_df: pd.DataFrame,
    ax=None,
    hue="age_correct",
    palette=["red", "green"],
):
    if ax is None:
        ax = plt.gca()

    title = f"age errors - {performance_df.attrs['splitname']} split"
    ax.set_title(title)
    # ax.set_xlim(-40, 40)
    return sns.histplot(
        data=performance_df,
        x="age_error",
        hue=hue,
        multiple="stack",
        palette=palette,
        # bins=np.arange(-40, 40, 1),
        ax=ax,
    )


def plot_age_scatter(performance_df: pd.DataFrame, ax=None, alpha=0.05):
    if ax is None:
        ax = plt.gca()

    title = f"age predictions - {performance_df.attrs['splitname']} split"
    ax.set_title(title)
    ax.axline((0, 0), (100, 100), color="green")

    return sns.scatterplot(
        data=performance_df,
        y="actual_age",
        x="pred_age",
        alpha=alpha,
        hue="age_correct",
        palette=["red", "green"],
        ax=ax,
    )


def plot_gender_errors(performance_df: pd.DataFrame, ax=None):
    if ax is None:
        ax = plt.gca()

    title = f"gender errors - {performance_df.attrs['splitname']} split"
    ax.set_title(title)
    return sns.countplot(
        data=performance_df,
        x="gender_error",
        hue="gender_correct",
        dodge=False,
        palette={True: "green", False: "red"},
        ax=ax,
    )
In [114]:
performance_train = performance_for_split("train")
performance_train
15291
(tensor([[24.3281],
         [ 3.6504],
         [29.9688],
         ...,
         [43.6875],
         [29.4844],
         [23.0156]]),
 tensor([[-1.8779,  1.8291],
         [ 0.0383, -0.0121],
         [-1.8516,  1.8262],
         ...,
         [ 4.3555, -4.6094],
         [-1.8867,  1.8916],
         [-3.0371,  2.9609]]))
Out[114]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct
0 26.0 24.328125 1 1 dataset/splits/train/26_1_4_20170117154131789.... 4 -1.671875 0 True True
1 2.0 3.650391 1 0 dataset/splits/train/2_1_4_20161221203029673.j... 4 1.650391 -1 True False
2 30.0 29.968750 1 1 dataset/splits/train/30_1_0_20170109001620649.... 0 -0.031250 0 True True
3 25.0 23.843750 0 0 dataset/splits/train/25_0_0_20170120221436173.... 0 -1.156250 0 True True
4 10.0 11.773438 0 0 dataset/splits/train/10_0_4_20170103202338152.... 4 1.773438 0 True True
... ... ... ... ... ... ... ... ... ... ...
15286 55.0 49.843750 0 0 dataset/splits/train/55_0_0_20170104184424541.... 0 -5.156250 0 True True
15287 43.0 41.812500 0 0 dataset/splits/train/43_0_3_20170119181404861.... 3 -1.187500 0 True True
15288 35.0 43.687500 0 0 dataset/splits/train/35_0_0_20170104183852983.... 0 8.687500 0 False True
15289 35.0 29.484375 1 1 dataset/splits/train/35_1_0_20170117144916091.... 0 -5.515625 0 True True
15290 18.0 23.015625 1 1 dataset/splits/train/18_1_0_20170117140343665.... 0 5.015625 0 True True

15291 rows × 10 columns

In [115]:
performance_val = performance_for_split("val")
performance_val
3277
(tensor([[67.3125],
         [35.7500],
         [42.6250],
         ...,
         [30.5469],
         [23.5000],
         [39.2812]]),
 tensor([[ 3.0098, -3.1895],
         [-2.5020,  2.4551],
         [ 3.0449, -3.2480],
         ...,
         [-1.8887,  1.8955],
         [ 2.1738, -2.3848],
         [ 4.8281, -5.1016]]))
Out[115]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct
0 61.0 67.312500 0 0 dataset/splits/val/61_0_3_20170119211856632.jp... 3 6.312500 0 True True
1 26.0 35.750000 1 1 dataset/splits/val/26_1_0_20170117153717556.jp... 0 9.750000 0 False True
2 26.0 42.625000 0 0 dataset/splits/val/26_0_0_20170117120944631.jp... 0 16.625000 0 False True
3 46.0 55.093750 1 1 dataset/splits/val/46_1_0_20170104184041597.jp... 0 9.093750 0 False True
4 37.0 37.875000 0 0 dataset/splits/val/37_0_0_20170119180034627.jp... 0 0.875000 0 True True
... ... ... ... ... ... ... ... ... ... ...
3272 72.0 65.187500 0 0 dataset/splits/val/72_0_2_20170105174444334.jp... 2 -6.812500 0 True True
3273 35.0 40.500000 0 0 dataset/splits/val/35_0_0_20170117182852603.jp... 0 5.500000 0 True True
3274 29.0 30.546875 1 1 dataset/splits/val/29_1_1_20170112204807283.jp... 1 1.546875 0 True True
3275 13.0 23.500000 0 0 dataset/splits/val/13_0_3_20170110232628896.jp... 3 10.500000 0 False True
3276 32.0 39.281250 0 0 dataset/splits/val/32_0_0_20170117203115358.jp... 0 7.281250 0 True True

3277 rows × 10 columns

In [116]:
performance_test = performance_for_split("test")
performance_test
3277
(tensor([[54.1562],
         [46.2188],
         [36.7500],
         ...,
         [42.4375],
         [12.6719],
         [76.3750]]),
 tensor([[ 4.3945, -4.6641],
         [ 2.9160, -3.0664],
         [-1.8604,  1.8027],
         ...,
         [ 3.3223, -3.5234],
         [-1.2158,  1.2852],
         [ 0.8242, -0.8218]]))
Out[116]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct
0 44.0 54.156250 0 0 dataset/splits/test/44_0_3_20170119204704727.j... 3 10.156250 0 False True
1 48.0 46.218750 0 0 dataset/splits/test/48_0_0_20170109012109036.j... 0 -1.781250 0 True True
2 49.0 36.750000 1 1 dataset/splits/test/49_1_1_20170113000544753.j... 1 -12.250000 0 False True
3 61.0 73.500000 1 1 dataset/splits/test/61_1_0_20170120225333848.j... 0 12.500000 0 False True
4 1.0 2.515625 0 1 dataset/splits/test/1_0_0_20161219204552941.jp... 0 1.515625 1 True False
... ... ... ... ... ... ... ... ... ... ...
3272 23.0 25.703125 1 1 dataset/splits/test/23_1_0_20170117145019683.j... 0 2.703125 0 True True
3273 12.0 21.859375 0 1 dataset/splits/test/12_0_0_20170117165940524.j... 0 9.859375 1 False False
3274 46.0 42.437500 0 0 dataset/splits/test/46_0_0_20170104203049435.j... 0 -3.562500 0 True True
3275 29.0 12.671875 1 1 dataset/splits/test/29_1_1_20170114024736192.j... 1 -16.328125 0 False True
3276 75.0 76.375000 0 0 dataset/splits/test/75_0_3_20170111210912724.j... 3 1.375000 0 True True

3277 rows × 10 columns

In [117]:
@run
@cached_chart(force=run_entire_notebook())
def model_performance_all_splits():

    f, ax = plt.subplots(3, 3, figsize=(15, 12))

    plot_age_errors(performance_train, ax=ax[0, 0])
    plot_age_errors(performance_val, ax=ax[1, 0])
    plot_age_errors(performance_test, ax=ax[2, 0])
    ax[0, 0].set_ylim(0, 2500)
    ax[1, 0].set_ylim(0, 500)
    ax[2, 0].set_ylim(0, 500)

    plot_age_scatter(performance_train, ax=ax[0, 1], alpha=0.03)
    plot_age_scatter(performance_val, ax=ax[1, 1], alpha=0.1)
    plot_age_scatter(performance_test, ax=ax[2, 1], alpha=0.1)

    plot_gender_errors(performance_train, ax=ax[0, 2])
    plot_gender_errors(performance_val, ax=ax[1, 2])
    plot_gender_errors(performance_test, ax=ax[2, 2])

    plt.tight_layout()
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_performance_all_splits.png]

Let's zoom into the age predictions scatterplots

In [118]:
@run
@cached_chart(force=run_entire_notebook())
def model_performance_scatter_age():
    f, ax = plt.subplots(1, 2, figsize=(15, 8))

    plot_age_scatter(performance_train, ax=ax[0], alpha=0.03)
    plot_age_scatter(performance_test, ax=ax[1], alpha=0.1)

    plt.tight_layout()
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_performance_scatter_age.png]

Notes on reading these charts:

  • ⚠️ Do not use the brightness/darkness of the coloured dots to assess performance or draw conclusions: the alpha for the training dataset has been lowered compared to the other splits.
    • This is done to better see clusters of overlapping dots (darker areas), but it affects the overall brightness of the dots and might give the impression that the training dataset is performing better/worse than it is.

Some thoughts:

  • Age prediction error charts look as we expected:
    • a peak centered around 0 (no error) and long tails on both sides.
    • most values (91%) are within the +/- 8 years range
    • the scatterplot shows clear horizontal lines with horizontal blurring (because predictions have decimals) and no vertical blurring (because actual values are integers)
  • Gender predictions look really good, with minimal errors.
  • Good generalization
    • No signs of overfitting
    • No unexpected behaviour or differences between train/val/test splits
  • The age error distribution even looks like it could fit a Laplace (double exponential) distribution; we could check later how well it fits, just for fun.
    • We did fit this, but it didn't contribute meaningfully, so it was scrapped to keep the notebook concise.
    • Observations:
      • p-value = 3e-6 (the distribution did not really fit a standard Laplace distribution; even on visual inspection, the peak was not as pronounced as it should be, and it was mildly shifted towards the right)

If we use the +/- 8 years tolerance to score age predictions as correct/incorrect, we get quite good results.

In [119]:
performance_train["age_correct"].value_counts(normalize=True) * 100
Out[119]:
age_correct
True     88.692695
False    11.307305
Name: proportion, dtype: float64
In [120]:
performance_val["age_correct"].value_counts(normalize=True) * 100
Out[120]:
age_correct
True     80.958193
False    19.041807
Name: proportion, dtype: float64
In [121]:
performance_test["age_correct"].value_counts(normalize=True) * 100
Out[121]:
age_correct
True     87.610619
False    12.389381
Name: proportion, dtype: float64

Final model performance¶

In [122]:
@run
@cached_chart(force=run_entire_notebook())
def model_confusion_matrix_gender():
    ConfusionMatrixDisplay.from_predictions(
        performance_test["actual_gender"],
        performance_test["pred_gender"],
        normalize="true",
        cmap="Greys_r",
    )
    plt.grid(False)
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_confusion_matrix_gender.png]

Clustering analysis¶

This is not foolproof, because the cluster boundaries are arbitrary (a difference of 0.1 years could be marked as an "error" if it happens to fall between, say, 29.96 and 30.05 years), but it's good enough to get a rough idea.
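As a quick illustration of that boundary effect, decade bucketing via floor division by 10 (the same operation used in the next cell) puts two predictions barely 0.1 years apart into different buckets:

```python
# Decade bucketing as used for the confusion matrix: floor-division by 10.
# Two predictions only ~0.1 years apart can fall into different buckets.
def decade(age: float) -> int:
    return int(age // 10)

assert decade(29.96) == 2  # late-twenties bucket
assert decade(30.05) == 3  # thirties bucket, despite a ~0.1 yr difference
assert decade(30.05) - decade(29.96) == 1
```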

In [123]:
performance_test["actual_decade"] = performance_test["actual_age"] // 10
performance_test["pred_decade"] = performance_test["pred_age"] // 10
In [124]:
@run
@cached_chart(force=run_entire_notebook())
def model_confusion_matrix_age_decades():
    plt.figure(figsize=(12, 12))
    ConfusionMatrixDisplay.from_predictions(
        performance_test["actual_decade"],
        performance_test["pred_decade"],
        normalize="true",
        values_format=",.2f",
        cmap="Greys_r",
        ax=plt.gca(),
    )
    plt.grid(False)
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_confusion_matrix_age_decades.png]

Observations:

  • The model tends to underestimate age, usually by a few years (if at all)
  • The model performs poorly with ages above 80, likely due to the low number of training images.

Choice of pretrained model¶

We prepared and tested different sizes of pretrained models (resnet34, 50, 101), but in the end the smallest one gave us fairly good performance, so there was little need to train a more complex one.

Analysis of Model Bias and underlying issues¶

In this section we would like to explore a few things:

  • Model Bias
    • Does our model show any bias or unfair/unequal treatment across different slices of the population?
    • If so, are any of these similar to the biases in the original dataset?

Let's bring back the original chart showing the split based on each of the attributes

In [125]:
@run
@cached_chart(force=False)
def dataset_labels():
    raise NotImplementedError(
        "Cached chart not found in cache/ dir.\n"
        "Consider re-running the notebook from end to end, "
        "or download the entire repository"
    )
Loading from cache [./cached/charts/dataset_labels.png]

Recall that, despite some of these appearing balanced at first glance (gender, for example), once we dig deeper we see clear imbalances across multiple dimensions (the white ethnicity has mostly male pics, while the asian ethnicity has almost 2x as many female pics as male).

In [126]:
@run
@cached_chart(force=False)
def population_breakdown():
    raise NotImplementedError(
        "Cached chart not found in cache/ dir.\n"
        "Consider re-running the notebook from end to end, "
        "or download the entire repository"
    )
Loading from cache [./cached/charts/population_breakdown.png]

Inspecting model performance across ethnicity¶

In [127]:
performance_test
Out[127]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct actual_decade pred_decade
0 44.0 54.156250 0 0 dataset/splits/test/44_0_3_20170119204704727.j... 3 10.156250 0 False True 4.0 5.0
1 48.0 46.218750 0 0 dataset/splits/test/48_0_0_20170109012109036.j... 0 -1.781250 0 True True 4.0 4.0
2 49.0 36.750000 1 1 dataset/splits/test/49_1_1_20170113000544753.j... 1 -12.250000 0 False True 4.0 3.0
3 61.0 73.500000 1 1 dataset/splits/test/61_1_0_20170120225333848.j... 0 12.500000 0 False True 6.0 7.0
4 1.0 2.515625 0 1 dataset/splits/test/1_0_0_20161219204552941.jp... 0 1.515625 1 True False 0.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ...
3272 23.0 25.703125 1 1 dataset/splits/test/23_1_0_20170117145019683.j... 0 2.703125 0 True True 2.0 2.0
3273 12.0 21.859375 0 1 dataset/splits/test/12_0_0_20170117165940524.j... 0 9.859375 1 False False 1.0 2.0
3274 46.0 42.437500 0 0 dataset/splits/test/46_0_0_20170104203049435.j... 0 -3.562500 0 True True 4.0 4.0
3275 29.0 12.671875 1 1 dataset/splits/test/29_1_1_20170114024736192.j... 1 -16.328125 0 False True 2.0 1.0
3276 75.0 76.375000 0 0 dataset/splits/test/75_0_3_20170111210912724.j... 3 1.375000 0 True True 7.0 7.0

3277 rows × 12 columns

In [128]:
@run
@cached_chart(force=run_entire_notebook())
def model_performance_age_scatter_by_ethnicity_and_gender():
    f, ax = plt.subplots(2, 5, figsize=(20, 10))

    for col, ethnicity in enumerate(ethnicity_map):
        e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
        for row, gender in enumerate(gender_map):
            g_subset = e_subset[e_subset["actual_gender"] == gender]
            plot_age_scatter(g_subset, ax=ax[row, col], alpha=0.1)
            ax[row, col].set_title(
                f"age predictions for {ethnicity_map[ethnicity]}, {gender_map[gender]}"
            )

    plt.tight_layout()
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_performance_age_scatter_by_ethnicity_and_gender.png]

Observations:

  • Most errors seem to be spread equally across both sides of the line (+/- 8 yrs), but for the indian male cluster the model seems to have a strong tendency to consistently predict either "baby" or "40 yrs old": the errors are not spread equally, but exhibit vertical clustering around x=40.
    • This could just be pareidolia caused by the small amount of data, but the same pattern appears if we plot the errors for the training split instead of the test one.
In [129]:
@run
@cached_chart(force=run_entire_notebook())
def model_performance_age_errors_by_ethnicity_and_gender():
    f, ax = plt.subplots(2, 5, figsize=(18, 8))

    for col, ethnicity in enumerate(ethnicity_map):
        e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
        for row, gender in enumerate(gender_map):
            g_subset = e_subset[e_subset["actual_gender"] == gender]
            plot_age_errors(g_subset, ax=ax[row, col])
            ax[row, col].set_title(
                f"age errors for {ethnicity_map[ethnicity]}, {gender_map[gender]}"
            )
            # ax[row, col].set_ylim(0, 10)

    plt.tight_layout()
    return plt.gcf()
skipping optional operation
Loading from cache [./cached/charts/model_performance_age_errors_by_ethnicity_and_gender.png]

Observations:

  • Regarding differences across ethnicities: we see again that the group with the most errors is ethnicity=white, likely because it's the only subset of pictures that covers the entire age range. Other ethnicities tend to have either 1-year-old babies or young adults.
    • Considering that skin colour plays a significant role, the model could be using this information to decide whether to perform deep inference (if white) or to predict "young adult, ~25 yrs" as a safe bet, since this is where most of the non-white subjects are clustered.
  • Regarding differences across genders: for the groups that had more examples of female pics (asian and indian), the model makes fewer errors on the female subsets, i.e. fewer errors where there were more training pictures.
In [130]:
@run
@cached_chart(force=run_entire_notebook(value_only=True))
def model_performance_gender_by_ethnicity():
    f, ax = plt.subplots(2, 5, figsize=(20, 6))

    for col, ethnicity in enumerate(ethnicity_map):
        e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
        ConfusionMatrixDisplay.from_predictions(
            e_subset["actual_gender"],
            e_subset["pred_gender"],
            display_labels=gender_map.values(),
            cmap="Greys_r",
            ax=ax[0, col],
        )
        ConfusionMatrixDisplay.from_predictions(
            e_subset["actual_gender"],
            e_subset["pred_gender"],
            normalize="true",
            values_format=",.2%",
            display_labels=gender_map.values(),
            cmap="Greys_r",
            ax=ax[1, col],
        )

        ax[0, col].grid(False)
        ax[1, col].grid(False)
        ax[0, col].set_title(f"gender predictions for {ethnicity_map[ethnicity]}")
        ax[1, col].set_title(
            f"gender predictions for {ethnicity_map[ethnicity]} - normalized"
        )

    plt.grid(False)
    plt.tight_layout()
    return plt.gcf()
Loading from cache [./cached/charts/model_performance_gender_by_ethnicity.png]

Analysing errors¶

Age Errors¶

In [131]:
top_x_age_errors = 100
top_age_errors = pd.concat(
    [
        performance_test.sort_values(by="age_error")[:top_x_age_errors],
        performance_test.sort_values(by="age_error")[-top_x_age_errors:],
    ]
)
top_age_errors.head(20)
Out[131]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct actual_decade pred_decade
2545 70.0 30.875000 0 0 dataset/splits/test/70_0_0_20170111200757701.j... 0 -39.125000 0 False True 7.0 3.0
1929 116.0 81.500000 1 1 dataset/splits/test/116_1_0_20170120134921760.... 0 -34.500000 0 False True 11.0 8.0
1023 77.0 47.656250 0 0 dataset/splits/test/77_0_1_20170116210256280.j... 1 -29.343750 0 False True 7.0 4.0
1964 57.0 31.000000 0 0 dataset/splits/test/57_0_0_20170117172532619.j... 0 -26.000000 0 False True 5.0 3.0
1141 116.0 92.937500 1 1 dataset/splits/test/116_1_2_20170112220255503.... 2 -23.062500 0 False True 11.0 9.0
3056 86.0 62.968750 1 1 dataset/splits/test/86_1_0_20170120225525242.j... 0 -23.031250 0 False True 8.0 6.0
1956 90.0 68.812500 0 0 dataset/splits/test/90_0_0_20170120230038954.j... 0 -21.187500 0 False True 9.0 6.0
1157 54.0 33.343750 1 1 dataset/splits/test/54_1_0_20170117171505517.j... 0 -20.656250 0 False True 5.0 3.0
1974 54.0 33.406250 0 1 dataset/splits/test/54_0_0_20170113210127075.j... 0 -20.593750 1 False False 5.0 3.0
2737 80.0 59.656250 0 0 dataset/splits/test/80_0_0_20170117173234032.j... 0 -20.343750 0 False True 8.0 5.0
912 65.0 45.968750 0 0 dataset/splits/test/65_0_1_20170113145609182.j... 1 -19.031250 0 False True 6.0 4.0
2685 61.0 42.000000 1 1 dataset/splits/test/61_1_0_20170117174551886.j... 0 -19.000000 0 False True 6.0 4.0
2302 67.0 48.218750 0 0 dataset/splits/test/67_0_0_20170113210319928.j... 0 -18.781250 0 False True 6.0 4.0
2404 49.0 30.468750 0 0 dataset/splits/test/49_0_3_20170119205458583.j... 3 -18.531250 0 False True 4.0 3.0
301 50.0 31.515625 1 1 dataset/splits/test/50_1_0_20170105162633419.j... 0 -18.484375 0 False True 5.0 3.0
2486 58.0 39.781250 1 1 dataset/splits/test/58_1_1_20170113012224304.j... 1 -18.218750 0 False True 5.0 3.0
1543 70.0 51.875000 1 1 dataset/splits/test/70_1_0_20170120134300289.j... 0 -18.125000 0 False True 7.0 5.0
2386 89.0 71.062500 0 0 dataset/splits/test/89_0_1_20170117182437361.j... 1 -17.937500 0 False True 8.0 7.0
1298 51.0 33.281250 0 0 dataset/splits/test/51_0_1_20170113142040362.j... 1 -17.718750 0 False True 5.0 3.0
864 70.0 52.750000 1 1 dataset/splits/test/70_1_0_20170117163559185.j... 0 -17.250000 0 False True 7.0 5.0
In [132]:
plot_age_errors_color(
    top_age_errors.sort_values(by="ethnicity", ascending=False),
    hue="ethnicity",
    palette="viridis",
)
Out[132]:
<AxesSubplot: title={'center': 'age errors - test split'}, xlabel='age_error', ylabel='Count'>

No major differences in the predictions based on the ethnicity label

In [133]:
fig, ax = plt.subplots(1, 5, figsize=(16, 4))

for i, (e, ethnicity) in enumerate(ethnicity_map.items()):
    cax = ax[i]
    top_age_errors_ethnic = top_age_errors[top_age_errors["ethnicity"] == e]

    plot_age_scatter(top_age_errors_ethnic, ax=cax, alpha=0.5)

plt.tight_layout()
plt.show()
/tmp/ipykernel_15851/1151474415.py:49: UserWarning: The palette list has more values (2) than needed (1), which may not be intended.
  return sns.scatterplot(

Nothing new compared to what we already saw earlier.

Model Explainability¶

Since LIME is model-agnostic, we need to make sure that we bridge the gap between this library and our custom model.

This means we must provide the data in the same way that our model transforms and pre-processes it during training and prediction.

Recall that our preprocessing pipeline uses no data augmentation (for now) and is quite simple: resizing and cropping.
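LIME's image explainer expects a batch-prediction function that takes numpy images and returns per-class scores. The sketch below (not the project's actual wrapper; `center_crop`, `make_classifier_fn`, and `stub_predict` are hypothetical names, and the stub stands in for our learner) only illustrates the key point from the text: the wrapper must apply the same preprocessing the model saw in training before delegating to the model.

```python
import numpy as np

def center_crop(img: np.ndarray, size: int) -> np.ndarray:
    """Center-crop an (H, W, C) image to (size, size, C)."""
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]

def make_classifier_fn(model_predict, crop_size: int = 200):
    """Build a LIME-compatible classifier_fn: batch of images -> scores."""
    def classifier_fn(batch: np.ndarray) -> np.ndarray:
        # Apply the SAME preprocessing as training: crop, scale to [0, 1].
        processed = np.stack(
            [center_crop(img, crop_size).astype(np.float32) / 255.0 for img in batch]
        )
        return model_predict(processed)
    return classifier_fn

# Stub model: "predicts" two gender scores from mean brightness, shapes only.
def stub_predict(x: np.ndarray) -> np.ndarray:
    m = x.mean(axis=(1, 2, 3))
    return np.stack([m, 1.0 - m], axis=1)

fn = make_classifier_fn(stub_predict)
scores = fn(np.zeros((4, 224, 224, 3), dtype=np.uint8))
assert scores.shape == (4, 2)  # one row per image, one column per class
```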

In [311]:
explain_files = glob.glob("dataset/explain/*.jpg")

Let's try with the pics that got the worst errors

In [318]:
age_underestimate = top_age_errors[:5]
age_overestimate = top_age_errors[-5:]
In [319]:
age_underestimate
Out[319]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct actual_decade pred_decade
2545 70.0 30.87500 0 0 dataset/splits/test/70_0_0_20170111200757701.j... 0 -39.12500 0 False True 7.0 3.0
1929 116.0 81.50000 1 1 dataset/splits/test/116_1_0_20170120134921760.... 0 -34.50000 0 False True 11.0 8.0
1023 77.0 47.65625 0 0 dataset/splits/test/77_0_1_20170116210256280.j... 1 -29.34375 0 False True 7.0 4.0
1964 57.0 31.00000 0 0 dataset/splits/test/57_0_0_20170117172532619.j... 0 -26.00000 0 False True 5.0 3.0
1141 116.0 92.93750 1 1 dataset/splits/test/116_1_2_20170112220255503.... 2 -23.06250 0 False True 11.0 9.0
In [320]:
age_overestimate
Out[320]:
actual_age pred_age actual_gender pred_gender filename ethnicity age_error gender_error age_correct gender_correct actual_decade pred_decade
575 27.0 44.437500 0 0 dataset/splits/test/27_0_2_20170119193329569.j... 2 17.437500 0 False True 2.0 4.0
565 40.0 62.312500 0 0 dataset/splits/test/40_0_0_20170113210319647.j... 0 22.312500 0 False True 4.0 6.0
397 36.0 58.531250 0 0 dataset/splits/test/36_0_0_20170113210318892.j... 0 22.531250 0 False True 3.0 5.0
776 1.0 25.328125 0 0 dataset/splits/test/1_0_3_20170104230640081.jp... 3 24.328125 0 False True 0.0 2.0
1237 1.0 28.125000 0 1 dataset/splits/test/1_0_4_20161221193016140.jp... 4 27.125000 1 False False 0.0 2.0

Checking the "Age Underestimated" errors¶

In [352]:
@run
@cached_chart()
def explain_image_underestimate_0_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_underestimate.iloc[0].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
In [353]:
@run
@cached_chart()
def explain_image_underestimate_1_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_underestimate.iloc[1].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
In [354]:
@run
@cached_chart()
def explain_image_underestimate_2_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_underestimate.iloc[2].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
In [355]:
@run
@cached_chart()
def explain_image_underestimate_3_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_underestimate.iloc[3].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.55)'}>
In [356]:
@run
@cached_chart()
def explain_image_underestimate_4_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_underestimate.iloc[4].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>

Observations:

  • We have checked the 5 worst examples of underestimation.
  • The model makes predictions that I could easily have made myself. I suspect that some of these pictures are mislabeled
  • The last one might not be mislabeled, but I would easily make the same prediction that the model made

Outcomes:

  • No concerns around issues with the model

Checking the "Age Overestimated" errors¶

In [357]:
@run
@cached_chart()
def explain_image_overestimate_0_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_overestimate.iloc[0].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
In [358]:
@run
@cached_chart()
def explain_image_overestimate_1_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_overestimate.iloc[1].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
In [360]:
@run
@cached_chart()
def explain_image_overestimate_2_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_overestimate.iloc[2].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
In [361]:
@run
@cached_chart()
def explain_image_overestimate_3_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_overestimate.iloc[3].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
In [362]:
@run
@cached_chart()
def explain_image_overestimate_4_age():
    return lime_explain_pics.lime_mask_explain(
        learn,
        age_overestimate.iloc[4].filename,
        "age",
        num_samples=100,
    )
  0%|          | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>

Observations:

  • We have checked the 5 worst examples of overestimation.
  • The model makes predictions that I could easily have made myself.
  • Some of these pictures are clearly mislabeled:
    • the last 2 are good examples of people who are not 1-year-old babies!
    • the 3rd picture does not look like someone who is 36... they seem to be around 70

Outcomes:

  • In future work, we should re-inspect the training material, using LIME to find mislabeled pictures and remove them from the training dataset.

Summary¶

In this project we have created a custom fastai model that uses a pretrained neural network to perform compound predictions using computer vision.

The pretrained model is based on ResNet34 and uses the weights from IMAGENET1K_V1.

We decided to use the smallest version of this pretrained model as it performed well enough; if more performance were required, we would expect larger and more complex pretrained models to perform slightly better.
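The multi-task wiring described above can be sketched as one shared backbone feature vector feeding two heads: an age regressor (one output) and a gender classifier (two logits). This is a minimal numpy illustration of the wiring only, not the actual fastai model; the weight matrices and `multi_task_head` are hypothetical stand-ins, with 512 being resnet34's pooled feature size.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features = 512  # resnet34's pooled feature size

# One shared feature vector feeds two task-specific heads.
W_age = rng.normal(size=(n_features, 1)) * 0.01     # regression head
W_gender = rng.normal(size=(n_features, 2)) * 0.01  # classification head

def multi_task_head(features: np.ndarray):
    """features: (batch, n_features) -> (age preds (batch,), gender logits (batch, 2))"""
    age = features @ W_age
    gender_logits = features @ W_gender
    return age[:, 0], gender_logits

feats = rng.normal(size=(8, n_features))
age_pred, gender_logits = multi_task_head(feats)
assert age_pred.shape == (8,) and gender_logits.shape == (8, 2)
```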

Executive Summary¶

Performance¶

Age Prediction¶

For age prediction, the model matches the standard of average human performance (+/- 8 yrs) in the vast majority of cases: over 80% of predictions fall within this +/- 8 yrs range.

Gender Prediction¶

For gender prediction, the model's performance is almost perfect, achieving over 98% correct matches in all categories (combinations of gender and ethnicity).

Analysis of BIAS¶

About gender vs gender expression.

This project required using computer vision to assess and predict "gender" from a collection of faces. Since the only data we have about these people is their appearance, the only thing we can actually assess is their gender expression. These differences in terminology matter, as they help us have a more nuanced conversation about what this dataset actually includes.

Using the Genderbread Person diagram, we can see that we are only able to assess and predict the yellow dials related to gender expression and nothing beyond that.

About the model's bias¶

We see that the model performs better for the majority classes/clusters:

  • Young ( under 40s)
  • White
  • Male

We see that while other classes also show errors, the amount of data left after slicing by gender, age bucket, and ethnicity is small; within those slices we end up with similar results and no major problematic areas.

We could have used a more stringent criterion for "correct" age guesses, allowing a smaller error for young people and a larger error for older groups, but we already see that the predictions are within those tolerances, so we used a simpler criterion to get the point across easily. We don't expect the performance metrics to change significantly if we had used such a triangular range (instead of two parallel bars at +/- 8 yrs).
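That age-dependent ("triangular") tolerance could be sketched as below. This is a hypothetical illustration, not something used in this notebook; the breakpoints (`min_tol`, `max_tol`, `max_age`) are arbitrary choices for the example.

```python
# Hypothetical age-dependent tolerance: small allowed error for young
# subjects, growing linearly up to a cap for older ones.
def age_tolerance(actual_age: float, min_tol: float = 3.0,
                  max_tol: float = 12.0, max_age: float = 80.0) -> float:
    frac = min(actual_age, max_age) / max_age
    return min_tol + frac * (max_tol - min_tol)

def age_correct(actual: float, predicted: float) -> bool:
    return abs(predicted - actual) <= age_tolerance(actual)

assert age_correct(5, 7)       # 2 yr error, tolerance ~3.6 yrs at age 5
assert not age_correct(5, 12)  # 7 yr error exceeds a child's tolerance
assert age_correct(70, 60)     # 10 yr error is fine at 70 (tolerance ~10.9)
```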

About the Dataset¶

About dataset cleanliness¶

This dataset proved to have abundant duplication, which required extensive and intensive cleaning to remove duplicated, partially duplicated, and "similar-enough" images.

This proved tedious and expensive. The proposed solution is able to run 200 million comparisons in under a minute, but planning, preparing, and assessing the various performance/time tradeoffs took a substantial amount of time.

This was crucial to get right: otherwise the dataset would suffer from extreme data leakage, letting the model perform well even with little training, purely through memorization.
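The actual deduplication pipeline isn't reproduced here, but one common way to make this many comparisons tractable is to reduce each image to a short perceptual hash and compare hashes with vectorized bitwise operations. A minimal sketch under that assumption (the hash and distance functions below are illustrative, not the project's code):

```python
import numpy as np

def average_hash(img, hash_size=8):
    """Tiny perceptual hash: block-average down to hash_size x hash_size,
    then threshold each cell at the overall mean."""
    h, w = img.shape
    img = img[:h - h % hash_size, :w - w % hash_size]
    blocks = img.reshape(hash_size, img.shape[0] // hash_size,
                         hash_size, img.shape[1] // hash_size).mean(axis=(1, 3))
    return (blocks > blocks.mean()).flatten()

def pairwise_hamming(hashes):
    """All-pairs Hamming distances between boolean hash vectors, vectorized."""
    H = np.asarray(hashes, dtype=np.uint8)
    return (H[:, None, :] ^ H[None, :, :]).sum(axis=2)

rng = np.random.default_rng(0)
img = rng.random((64, 64))
near_dup = np.clip(img + rng.normal(0, 0.01, img.shape), 0, 1)  # tiny perturbation
unrelated = rng.random((64, 64))

d = pairwise_hamming([average_hash(x) for x in (img, near_dup, unrelated)])
# near-duplicates end up a small Hamming distance apart; unrelated images far apart
```

On top of hashing, bucketing hashes (e.g. by a short prefix) avoids the full O(n²) comparison when the number of images is large.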

About gender¶

💛💛💛💛💛💛💛💛💛
🤍🤍🤍🤍🤍🤍🤍🤍🤍
💜💜💜💜💜💜💜💜💜
🖤🖤🖤🖤🖤🖤🖤🖤🖤

In this TC project, we were required to use this kaggle dataset: https://www.kaggle.com/datasets/jangedoo/utkface-new

This dataset presents gender as a binary. However, it's crucial to understand that this binary representation is a simplification and doesn't reflect the complexity of gender.

Gender is not binary in any aspect - be it physical, physiological, hormonal, in terms of gender expression, or gender identity.

This binary labeling was used due to the constraints of the project, not as an endorsement of a binary view of gender.

From the perspective of queer theory and LGBT rights, this binary representation can be problematic. It overlooks the experiences and identities of those who don't fit within this binary, including transgender, non-binary, and genderqueer individuals. This can lead to erasure and marginalization, reinforcing harmful stereotypes and discrimination.

Therefore, while our project used a binary gender label, we acknowledge its limitations and advocate for more inclusive and nuanced representations of gender in data. We believe in the importance of recognizing and respecting all gender identities and expressions, as a fundamental aspect of human rights and dignity.

About this dataset being required for this project¶

This dataset, with its binary representation of gender, is far from ideal. It inadvertently perpetuates a narrow and oversimplified perception of gender among students. By presenting gender as a binary, it may lead students to internalize this limited understanding, thereby reinforcing the societal norms that queer theory and LGBTQIA+ rights movements challenge. This can hinder the development of a more inclusive, diverse, and accurate understanding of gender. It’s essential to critically engage with such datasets and question the assumptions they make, to foster a more comprehensive and respectful understanding of gender diversity. We must strive for datasets that reflect the reality of human experience, in all its richness and diversity.

About the Model and its performance¶

The model seemed somewhat prone to overfitting, showing strange behaviour when trained for more than 30 epochs:

This appeared as some wildly incorrect predictions (even in the training dataset):

These also appeared in the val dataset.

Interestingly, this bias affected only pictures of males.

This was fixed by reducing the number of training epochs.
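Rather than hand-picking an epoch count, the same effect can be obtained with early stopping on the validation loss. A framework-agnostic sketch (the class below is illustrative; Keras and PyTorch Lightning ship equivalent callbacks):

```python
class EarlyStopper:
    """Stop training once val loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss   # new best: reset the counter
            self.stale = 0
        else:
            self.stale += 1        # no improvement this epoch
        return self.stale >= self.patience

stopper = EarlyStopper(patience=3)
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73]  # improvement stalls after epoch 2
stopped_at = next(i for i, loss in enumerate(losses) if stopper.should_stop(loss))
# training stops at epoch 5; the best weights were seen at epoch 2
```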

Future Improvements¶

  • Consider using larger pretrained models if we want to improve performance
  • Consider retraining after further data cleaning:
    • Remove pictures that are clearly not pictures of people
    • Remove mislabeled pictures
    • These steps are now easy, since we have a model that can make predictions and report its confidence in them
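For the mislabel hunt, one simple approach is to flag images where the model confidently disagrees with the dataset label. A sketch, assuming the model outputs P(male) per image and labels use 1 = male, 0 = female (the names and threshold are illustrative):

```python
import numpy as np

def flag_suspicious(p_male, labels, threshold=0.9):
    """Return indices where the model confidently contradicts the label.

    Confident disagreement often points at a mislabeled image
    (or a picture that isn't a face at all).
    """
    p_male = np.asarray(p_male)
    labels = np.asarray(labels)
    confident_male = p_male >= threshold
    confident_female = p_male <= 1 - threshold
    return np.where((confident_male & (labels == 0)) |
                    (confident_female & (labels == 1)))[0]

# index 1 is labeled female but predicted male at 0.97: worth a manual review
print(flag_suspicious([0.95, 0.97, 0.05, 0.55], [1, 0, 0, 1]))  # [1]
```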